diff --git a/README.md b/README.md index 42d8b5f..e6b76d3 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,35 @@ This package annotates genetic variants with their predicted effect on splicing, SpliceAI source code is provided under the [GPLv3 license](LICENSE). SpliceAI includes several third party packages provided under other open source licenses, please see [NOTICE](NOTICE) for additional details. The trained models used by SpliceAI (located in this package at spliceai/models) are provided under the [CC BY NC 4.0](LICENSE) license for academic and non-commercial use; other use requires a commercial license from Illumina, Inc. ### Installation -The simplest way to install SpliceAI is through pip or conda: + +This release can most easily be used as a docker container: + +```sh +docker pull cmgantwerpen/spliceai_v1.3:latest + +docker run --gpus all cmgantwerpen/spliceai_v1.3:latest spliceai -h +``` + +A container including reference and annotation data is available as well: + + +```sh +docker pull cmgantwerpen/spliceai_v1.3:full +``` +Note that this version has a larger footprint (12Gb). Data is available for Genome Build hg19 and hg38 under /data/ + + + +The simplest way to install (the original version of) SpliceAI is through pip or conda: ```sh pip install spliceai # or conda install -c bioconda spliceai ``` -Alternately, SpliceAI can be installed from the [github repository](https://github.com/Illumina/SpliceAI.git): +Alternately, SpliceAI can be installed from the [github repository](https://github.com/invitae/SpliceAI.git): ```sh -git clone https://github.com/Illumina/SpliceAI.git +git clone https://github.com/invitae/SpliceAI.git cd SpliceAI python setup.py install ``` @@ -42,7 +61,7 @@ Required parameters: - ```-I```: Input VCF with variants of interest. - ```-O```: Output VCF with SpliceAI predictions `ALLELE|SYMBOL|DS_AG|DS_AL|DS_DG|DS_DL|DP_AG|DP_AL|DP_DG|DP_DL` included in the INFO column (see table below for details). Only SNVs and simple INDELs (REF or ALT is a single base) within genes are annotated. Variants in multiple genes have separate predictions for each gene. - ```-R```: Reference genome fasta file. Can be downloaded from [GRCh37/hg19](http://hgdownload.cse.ucsc.edu/goldenPath/hg19/bigZips/hg19.fa.gz) or [GRCh38/hg38](http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz). - - ```-A```: Gene annotation file. Can instead provide `grch37` or `grch38` to use GENCODE V24 canonical annotation files included with the package. To create custom gene annotation files, use `spliceai/annotations/grch37.txt` in repository as template. + - ```-A```: Gene annotation file. Can instead provide `grch37` or `grch38` to use GENCODE V24 canonical annotation files included with the package. To create custom gene annotation files, use `spliceai/annotations/grch37.txt` in repository as template and provide as full path. Optional parameters: - ```-D```: Maximum distance between the variant and gained/lost splice site (default: 50). @@ -50,33 +69,66 @@ Optional parameters: - ```-B```: Number of predictions to collect before running models on them in batch. (default: 1 (don't batch)) - ```-T```: Internal Tensorflow `predict()` batch size if you want something different from the `-B` value. (default: the `-B` value) - ```-V```: Enable verbose logging during run - -**Batching Considerations:** When setting the batching parameters, be mindful of the system and gpu memory of the machine you -are running the script on. Feel free to experiment, but some reasonable `-B` numbers would be 64/128. - -Batching Performance Benchmarks: - -| Type | Speed | -| -------- | ----------- | -| n1-standard-2 CPU (GCP) | ~800 per hour | -| CPU (2019 MacBook Pro) | ~3,000 per hour | -| K80 GPU (GCP) | ~25,000 per hour | -| V100 GPU (GCP) | ~150,000 per hour | - -Details of SpliceAI INFO field: - -| ID | Description | -| -------- | ----------- | -| ALLELE | Alternate allele | -| SYMBOL | Gene symbol | -| DS_AG | Delta score (acceptor gain) | -| DS_AL | Delta score (acceptor loss) | -| DS_DG | Delta score (donor gain) | -| DS_DL | Delta score (donor loss) | -| DP_AG | Delta position (acceptor gain) | -| DP_AL | Delta position (acceptor loss) | -| DP_DG | Delta position (donor gain) | -| DP_DL | Delta position (donor loss) | + - ```-t```: Specify a location to create the temporary files + - ```-G```: Specify the GPU(s) to run on : either indexed (eg : 0,2) or 'all'. (default: 'all') + - ```-S```: Simulate *n* multiple GPUs on a single physical device. Used for development only, currently all values above 2 crashed due to memory issues. (default: 0) + - ```-P```: Port to use when connecting to the socket (default: 54677, only used in batch mode). + +**Batching Considerations:** + +When setting the batching parameters, be mindful of the system and gpu memory of the machine you +are running the script on. Feel free to experiment, but some reasonable `-T` numbers would be 64/128. CPU memory is larger, and increasing `-B` might further improve performance. + +*Batching Performance Benchmarks:* +- Input data: GATK generated WES sample with ~ 90K variants in genome build GRCh37. +- Total predictions made : 174,237 +- invitae v2 mainly implements logic to prioritize full batches while predicting +- settings : + - invitae & invitae v2 : B = T = 64 + - invitae v2 optimal : on V100 : B = 4096 ; T = 256 -- on K80/GeForce : B = 4096 ; T = 64 + +*Benchmark results* + +| Type | Implementation | Total Time | Speed (predictions / hour) | +|--------------------------------------|-----------------------|------------|----------------------------| +| CPU (intel i5-8365U)a | illumina | ~100h | ~1000 pred/h | +| | invitae | ~39h | ~4500 pred/h | +| | invitae v2 | ~35h | ~5000 pred/h | +| | invitae v2 optimal | ~35h | ~5000 pred/h | +| K80 GPU (AWS p2.large) | illuminab | ~25 h | ~7000 pred/h | +| | invitae | 242m | ~43,000 pred / h | +| | invitae v2 | 213m | ~50,000 pred / h | +| | invitae v2 optimal | 188 m | ~56,000 pred / h | +| GeForce RTX 2070 SUPER GPU (desktop) | illuminab | ~10 h | ~ 17,000 pred/h | +| | invitae | 76m | ~137,000 pred / h | +| | invitae v2 | 63m | ~166,000 pred / h | +| | invitae v2 optimal | 52m | ~200,000 pred / h | +| V100 GPU (AWS p3.xlarge) | illuminab | ~10h | ~18,000 pred/h | +| | invitae | 78m | ~135,000 pred / h | +| | invitae v2 | 54m | ~190,000 pred / h | +| | invitae v2 optimal | 31 m | ~335,000 pred / h | + + +(a) : Extrapolated from first 500 variants + +(b) : Illumina implementation showed a memory leak with the installed versions of tf/keras/.... Values extrapolated from incomplete runs at the point of OOM. + +*Note:* On a p3.8xlarge machine, hosting 4 V100 GPU's, we were able reach 1,379,505 predictions/hour ! This is a nearly linear scale-up. + +### Details of SpliceAI INFO field: + +| ID | Description | +|--------|--------------------------------| +| ALLELE | Alternate allele | +| SYMBOL | Gene symbol | +| DS_AG | Delta score (acceptor gain) | +| DS_AL | Delta score (acceptor loss) | +| DS_DG | Delta score (donor gain) | +| DS_DL | Delta score (donor loss) | +| DP_AG | Delta position (acceptor gain) | +| DP_AL | Delta position (acceptor loss) | +| DP_DG | Delta position (donor gain) | +| DP_DL | Delta position (donor loss) | Delta score of a variant, defined as the maximum of (DS_AG, DS_AL, DS_DG, DS_DL), ranges from 0 to 1 and can be interpreted as the probability of the variant being splice-altering. In the paper, a detailed characterization is provided for 0.2 (high recall), 0.5 (recommended), and 0.8 (high precision) cutoffs. Delta position conveys information about the location where splicing changes relative to the variant position (positive values are downstream of the variant, negative values are upstream). @@ -133,5 +185,15 @@ donor_prob = y[0, :, 2] * Adds test cases to run a small file using a generated FASTA reference to test if the results are the same with no batching and with different batching sizes * Slightly modifies the entrypoint of running the code to allow for easier unit testing. Being able to pass in what would normally come from the argparser +**Multi-GPU support** - Geert Vandeweyer (_November 2022_) + +* Offload more code to CPU (eg np to tensor conversion) to *only* perform predictions on the GPU +* Implement queuing system to always have full batches ready for prediction +* Implement new parameter, `--tmpdir` to support a custom tmp folder to store prepped batches +* Implement socket-based client/server approach to scale over multiple GPUs + + ### Contact Kishore Jaganathan: kjaganathan@illumina.com + +Geert Vandeweyer (This implementation) : geert.vandeweyer@uza.be diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..4aefaff --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,67 @@ +###################################### +## CONTAINER FOR GPU based SpliceAI ## +###################################### + +# start from the cuda docker base +FROM nvidia/cuda:11.4.0-base-ubuntu20.04 + +LABEL version="1.3" +LABEL description="This container was tested with \ + - V100 on AWS p3.2xlarge with nvidia drivers 510.47.03 and cuda v11.6 \ + - K80 on AWS p2.xlarge with nvidia drivers 470.141.03 and cuda v11.4 \ + - Geforce RTX 2070 SUPER (local) with nvidia drivers 470.141.03 and cuda v11.4" + +LABEL author="Geert Vandeweyer" +LABEL author.email="geert.vandeweyer@uza.be" + +## needed apt packages +ARG BUILD_PACKAGES="wget git bzip2" +# needed conda packages + +ARG CONDA_PACKAGES="python=3.9.13 tensorflow-gpu=2.10.0 cuda-nvcc=11.8.89" + +## ENV SETTINGS during runtime +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 +ENV PATH=/opt/conda/bin:$PATH +ENV DEBIAN_FRONTEND noninteractive + +# For micromamba: +SHELL ["/bin/bash", "-l", "-c"] +ENV MAMBA_ROOT_PREFIX=/opt/conda/ +ENV PATH=/opt/micromamba/bin:/opt/conda/bin:$PATH +ARG CONDA_CHANNEL="-c bioconda -c conda-forge -c nvidia" + +## INSTALL +RUN apt-get -y update && \ + apt-get -y install $BUILD_PACKAGES && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + + +# conda packages +RUN mkdir /opt/conda && \ + mkdir /opt/micromamba && \ + wget -qO - https://micromamba.snakepit.net/api/micromamba/linux-64/0.23.0 | tar -xvj -C /opt/micromamba bin/micromamba && \ + # initialize bash + micromamba shell init --shell=bash --prefix=/opt/conda && \ + # remove a statement from bashrc that prevents initialization + grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/micromamba/bashrc && \ + mv /opt/micromamba/bashrc /root/.bashrc && \ + source ~/.bashrc && \ + # activate & install base conda packag + micromamba activate && \ + micromamba install -y $CONDA_CHANNEL $CONDA_PACKAGES && \ + micromamba clean --all --yes + +# Break cache for recloning git +ARG DATE_CACHE_BREAK=$(date) + +# my fork of spliceai : has gpu optimizations +RUN cd /opt/ && \ + git clone https://github.com/geertvandeweyer/SpliceAI.git && \ + cd SpliceAI && \ + python setup.py install + +# no command given, print help. +CMD spliceai -h + diff --git a/spliceai/__main__.py b/spliceai/__main__.py index 11687d2..79b4896 100644 --- a/spliceai/__main__.py +++ b/spliceai/__main__.py @@ -5,9 +5,18 @@ import argparse import logging import pysam - -from spliceai.batch.batch import VCFPredictionBatch +import time +import tempfile +from multiprocessing import Process,Queue,Pool +from functools import partial +import shutil +import tensorflow as tf +import subprocess as sp +import os + +from spliceai.batch.batch_utils import prepare_batches, start_workers,initialize_devices from spliceai.utils import Annotator, get_delta_scores +from spliceai.batch.data_handlers import VCFWriter,VCFReader try: from sys.stdin import buffer as std_in @@ -20,30 +29,38 @@ def get_options(): parser = argparse.ArgumentParser(description='Version: 1.3.1') - parser.add_argument('-I', metavar='input', nargs='?', default=std_in, + parser.add_argument('-P', '--port', metavar='port', type=int, default=54677, + help='option to change port if several GPUs on one network (default: 54677)') + parser.add_argument('-I', '--input_data', metavar='input', nargs='?', default=std_in, help='path to the input VCF file, defaults to standard in') - parser.add_argument('-O', metavar='output', nargs='?', default=std_out, + parser.add_argument('-O', '--output_data', metavar='output', nargs='?', default=std_out, help='path to the output VCF file, defaults to standard out') - parser.add_argument('-R', metavar='reference', required=True, + parser.add_argument('-R', '--reference', metavar='reference', required=True, help='path to the reference genome fasta file') - parser.add_argument('-A', metavar='annotation', required=True, + parser.add_argument('-A', '--annotation',metavar='annotation', required=True, help='"grch37" (GENCODE V24lift37 canonical annotation file in ' 'package), "grch38" (GENCODE V24 canonical annotation file in ' 'package), or path to a similar custom gene annotation file') - parser.add_argument('-D', metavar='distance', nargs='?', default=50, + parser.add_argument('-D', '--distance', metavar='distance', nargs='?', default=50, type=int, choices=range(0, 5000), help='maximum distance between the variant and gained/lost splice ' 'site, defaults to 50') - parser.add_argument('-M', metavar='mask', nargs='?', default=0, + parser.add_argument('-M', '--mask', metavar='mask', nargs='?', default=0, type=int, choices=[0, 1], help='mask scores representing annotated acceptor/donor gain and ' 'unannotated acceptor/donor loss, defaults to 0') - parser.add_argument('-B', '--prediction-batch-size', metavar='prediction_batch_size', default=1, type=int, + parser.add_argument('-B', '--prediction_batch_size', metavar='prediction_batch_size', default=1, type=int, help='number of predictions to process at a time, note a single vcf record ' 'may have multiple predictions for overlapping genes and multiple alts') - parser.add_argument('-T', '--tensorflow-batch-size', metavar='tensorflow_batch_size', type=int, + parser.add_argument('-T', '--tensorflow_batch_size', metavar='tensorflow_batch_size', type=int, help='tensorflow batch size for model predictions') parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging') + parser.add_argument('-t','--tmpdir', metavar='tmpdir',type=str,default='/tmp/',required=False, + help="Use Alternate location to store tmp files. (Note: B=4096 equals to roughly 15Gb of tmp files)") + parser.add_argument('-G','--gpus',metavar='gpus',type=str,default='all',required=False, + help="Number of GPUs to use for SpliceAI. Provide 'all', or comma-seperated list of GPUs to use. eg '0,2' (first and third). Defaults to 'all'") + parser.add_argument('-S', '--simulated_gpus',metavar='simulated_gpus',default='0',type=int, required=False, + help="For development: simulated logical gpus on a single physical device to simulate a multi-gpu environment") args = parser.parse_args() return args @@ -51,31 +68,145 @@ def get_options(): def main(): args = get_options() - + # logging if args.verbose: - logging.basicConfig( - format='%(asctime)s %(levelname)s %(name)s: - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=logging.DEBUG, - ) - - if None in [args.I, args.O, args.D, args.M]: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + logging.basicConfig( + format='%(asctime)s %(levelname)s %(name)s: - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=loglevel, + ) + # sanity check for mandatory arguments + if None in [args.input_data, args.output_data, args.distance, args.mask]: logging.error('Usage: spliceai [-h] [-I [input]] [-O [output]] -R reference -A annotation ' - '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]]') + '[-D [distance]] [-M [mask]] [-B [prediction_batch_size]] [-T [tensorflow_batch_size]] [-t [tmp_location]]') exit() - - # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args - tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size - - run_spliceai(input_data=args.I, output_data=args.O, reference=args.R, - annotation=args.A, distance=args.D, mask=args.M, - prediction_batch_size=args.prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size) - - -def run_spliceai(input_data, output_data, reference, annotation, distance, mask, prediction_batch_size, - tensorflow_batch_size): - + logging.debug(f"PORT:{args.port}") + + ## revised code for batched analysis + if args.prediction_batch_size > 1: + # initialize the GPU and setup to estimate + devices,mem_per_logical = initialize_devices(args) + # Default the tensorflow batch size to the prediction_batch_size if it's not supplied in the args + args.tensorflow_batch_size = args.tensorflow_batch_size if args.tensorflow_batch_size else args.prediction_batch_size + + # load annotation data: + ann = Annotator(args.reference, args.annotation) + logging.debug("Annotation loaded.") + # run + run_spliceai_batched(args,ann,devices,mem_per_logical) + + else: # run original code: + # load annotation + ann = Annotator(args.reference, args.annotation) + # run scoring + run_spliceai(args, ann) + + +## revised logic to allow batched tensorflow analysis on multiple GPUs +def run_spliceai_batched(args, ann,devices,mem_per_logical): + + ## GOAL + ## - launch a reader that preps & pickles input vcf + ## - launch per GPU/device, using sockets, a utility script that runs tasks from the queue on that device. + ## - communicate through sockets : server threads issue items from the queue to worker clients + ## - when all predictions are done, build the output vcf. + + + ## track start time + start_time = time.time() + ## variables: + input_data = args.input_data + output_data = args.output_data + distance = args.distance + mask = args.mask + prediction_batch_size = args.prediction_batch_size + tensorflow_batch_size = args.tensorflow_batch_size + + ## mk a temp directory + tmpdir = tempfile.mkdtemp(dir=args.tmpdir) # TemporaryDirectory(dir=args.tmpdir) + #tmpdir = tmpdir.name + logging.info("Using tmpdir : {}".format(tmpdir)) + + # creates a queue with max 10 ready-to-go batches in it. + prediction_queue = Queue(maxsize=10) + # starts processing & filling the queue. + reader_args={'ann':ann, 'args':args, 'tmpdir':tmpdir, 'prediction_queue': prediction_queue, 'nr_workers': len(devices)} + reader = Process(target=prepare_batches, kwargs=reader_args) + reader.start() + logging.debug("Reader started") + worker_clients, worker_servers, devices = start_workers(prediction_queue,tmpdir,args,devices,mem_per_logical) + logging.debug("workers started") + ## wait for everything to finish. + # => If exit codes != 0 are detected, the main process will exit with the first non-zero exit code. + while True: + # any exit codes defined and != 0 ? + exit_codes = [p.exitcode for p in worker_servers + [reader] if p.exitcode is not None] + [p.poll() for p in worker_clients if p.poll() is not None] + logging.debug("exit codes: {}".format(exit_codes)) + if any(rc != 0 for rc in exit_codes): + logging.error("Error encountered Exiting.") + # kill all processes + for p in worker_servers + [reader]: + if p.is_alive(): + p.kill() + for p in worker_clients: + if p.poll() is None: + p.kill() + # and exit + sys.exit(1) + if len(exit_codes) == len(worker_servers + [reader] + worker_clients): + break + time.sleep(30) + + # readers sends finish signal to workers + logging.info("Cleanup VCF reader") + reader.join() + logging.debug("Reader joined!") + # clients receive signal, send it to servers. + logging.info("Cleaning up workers.") + for p in worker_clients: + # subprocesses : wait() + p.wait() + logging.debug("Workers are done!") + logging.info("Waiting for servers to join.") + for p in worker_servers: + # mp processes : join() + p.join() + logging.debug("Servers are done") + + # stats without writing phase + prediction_duration = time.time() - start_time + + # write results. in/out from args, devices to get shelf names + logging.info("Writing output file") + writer = VCFWriter(args=args,tmpdir=tmpdir,devices=devices,ann=ann) + writer.process() + + # clear out tmp + shutil.rmtree(tmpdir) + ## stats + overall_duration = time.time() - start_time + preds_per_sec = writer.total_predictions / prediction_duration + preds_per_hour = preds_per_sec * 60 * 60 + logging.info("Analysis Finished. Statistics:") + logging.info("Total RunTime: {:0.2f}s".format(overall_duration)) + logging.info("Prediction RunTime: {:0.2f}s".format(prediction_duration)) + logging.info("Processed Records: {}".format(writer.total_vcf_records)) + logging.info("Processed Predictions: {}".format(writer.total_predictions)) + logging.info("Overall performance : {:0.2f} predictions/sec ; {:0.2f} predictions/hour".format(preds_per_sec, preds_per_hour)) + + +# original flow : record by record reading/predict/write +def run_spliceai(args, ann): + # assign variables + input_data = args.input_data + output_data = args_output_data + distance = args.distance + mask = args.mask + + # open infile try: vcf = pysam.VariantFile(input_data) except (IOError, ValueError) as e: @@ -94,39 +225,15 @@ def run_spliceai(input_data, output_data, reference, annotation, distance, mask, logging.error('{}'.format(e)) exit() - ann = Annotator(reference, annotation) - batch = None - - # Only use the batching code if we are batching - if prediction_batch_size > 1: - batch = VCFPredictionBatch( - ann=ann, - output=output_data, - dist=distance, - mask=mask, - prediction_batch_size=prediction_batch_size, - tensorflow_batch_size=tensorflow_batch_size, - ) - for record in vcf: - if batch: - # Add record to batch, if batch fills, then they will all be processed at once - batch.add_record(record) - else: - # If we're not batching, let's run the original code scores = get_delta_scores(record, ann, distance, mask) if len(scores) > 0: record.info['SpliceAI'] = scores output_data.write(record) - - if batch: - # Ensure we process any leftover records in the batch when we finish iterating the VCF. This - # would be a good candidate for a context manager if we removed the original non batching code above - batch.finish() - + + # close VCF vcf.close() output_data.close() - if __name__ == '__main__': main() diff --git a/spliceai/batch/batch.py b/spliceai/batch/batch.py index 12f98db..5841306 100644 --- a/spliceai/batch/batch.py +++ b/spliceai/batch/batch.py @@ -1,200 +1,181 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. -import collections +# Invitae source code modified to improve GPU utilization +# Modifications made by Geert Vandeweyer (Antwerp University Hospital, Belgium) + import logging import time - +import shelve import numpy as np +import os +import tensorflow as tf +import pickle +import gc +import socket +import sys +import argparse + +#from spliceai.batch.batch_utils import extract_delta_scores, get_preds +sys.path.append('../../../spliceai') +from spliceai.batch.batch_utils import get_preds, initialize_devices, initialize_one_device +from spliceai.utils import Annotator, get_delta_scores -from spliceai.batch.batch_utils import extract_delta_scores, get_preds, encode_batch_records -logger = logging.getLogger(__name__) SequenceType_REF = 0 SequenceType_ALT = 1 -BatchLookupIndex = collections.namedtuple( - 'BatchLookupIndex', 'sequence_type tensor_size batch_index' -) - -PreparedVCFRecord = collections.namedtuple( - 'PreparedVCFRecord', 'vcf_record gene_info locations' -) - - +# options : revised from __main__ +def get_options(): + + parser = argparse.ArgumentParser(description='Version: 1.3.1') + parser.add_argument('-P', '--port', metavar='port', required=True, type=int) + parser.add_argument('-R', '--reference', metavar='reference', required=True, + help='path to the reference genome fasta file') + parser.add_argument('-A', '--annotation',metavar='annotation', required=True, + help='"grch37" (GENCODE V24lift37 canonical annotation file in ' + 'package), "grch38" (GENCODE V24 canonical annotation file in ' + 'package), or path to a similar custom gene annotation file') + parser.add_argument('-T', '--tensorflow_batch_size', metavar='tensorflow_batch_size', type=int, + help='tensorflow batch size for model predictions') + parser.add_argument('-V', '--verbose', action='store_true', help='enables verbose logging') + parser.add_argument('-t','--tmpdir', metavar='tmpdir',type=str,default='/tmp/',required=False, + help="Use Alternate location to store tmp files. (Note: B=4096 equals to roughly 15Gb of tmp files)") + parser.add_argument('-d','--device',metavar='device',type=str,required=True, + help="CPU/GPU device to deploy worker on") + parser.add_argument('-S', '--simulated_gpus',metavar='simulated_gpus',default='0',type=int, required=False, + help="For development: simulated logical gpus on a single physical device to simulate a multi-gpu environment") + parser.add_argument('-M', '--mem_per_logical', metavar='mem_per_logical',default=0,type=int, required=False, + help="For simulated GPUs assign this amount of memory (Mb)") + parser.add_argument('-G','--gpus',metavar='gpus',type=str,default='all',required=False, + help="Number of GPUs to use for SpliceAI. Provide 'all', or comma-seperated list of GPUs to use. eg '0,2' (first and third). Defaults to 'all'") + args = parser.parse_args() + + return args + + +def main(): + # get arguments + args = get_options() + if args.verbose: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + logging.basicConfig( + format='%(asctime)s %(levelname)s %(name)s: - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=loglevel, + ) + logger = logging.getLogger(__name__) + + # initialize && assign device + if args.simulated_gpus > 0: + devices = [x for x in initialize_devices(args)[0] if x.name == args.device] + else: + # no simulation : expose only the requested device to tensor. + devices = initialize_one_device(args) + + + if not devices: + logger.error(f"Specified device '{args.device}' not found!") + sys.exit(1) + device = devices[0].name + with tf.device(device): + logger.info(f"Working on device {args.device}") + # initialize the VCFPredictionBatch, pass (non-masked) device name + worker = VCFPredictionBatch(args=args,logger=logger) + # start working ! + worker.process_batches() + # done. + + + + +# Class to handle predictions class VCFPredictionBatch: - def __init__(self, ann, output, dist, mask, prediction_batch_size, tensorflow_batch_size): - self.ann = ann - self.output = output - self.dist = dist - self.mask = mask - # This is the maximum number of predictions to parse/encode/predict at a time - self.prediction_batch_size = prediction_batch_size - # This is the size of the batch tensorflow will use to make the predictions - self.tensorflow_batch_size = tensorflow_batch_size - - # Batch vars - self.batches = {} - self.prepared_vcf_records = [] - - # Counts - self.batch_predictions = 0 - self.total_predictions = 0 - self.total_vcf_records = 0 - - def _clear_batch(self): - self.batch_predictions = 0 - self.batches.clear() - del self.prepared_vcf_records[:] - - def _process_batch(self): + def __init__(self, args, logger): + self.args = args + self.ann = None + self.tensorflow_batch_size = args.tensorflow_batch_size + self.tmpdir = args.tmpdir + self.device = args.device + self.logger = logger + + # store batches of predictions using 'tensor_size|batch_idx' as key. + self.shelf_preds_name = f"spliceai_preds.{self.device[1:].replace(':','_')}.shelf" + self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) + + # monitor the queue and submit incoming batches. + def process_batches(self): + with socket.socket() as s: + host = socket.gethostname() # locahost + port = self.args.port + try: + s.connect((host,port)) + except Exception as e: + raise(e) + # first response : server is running + res = s.recv(2048) + # then start polling queue + msg = "Ready for work..." + + # first load annotation + if not self.ann: + # load annotation + self.ann = Annotator(self.args.reference, self.args.annotation,cpu=True) + while True: + # send request for work + s.send(str.encode(msg)) + res = s.recv(2048).decode('utf-8') + # response can be a job, 'hold on' for empty queue, or 'Done' for all finished. + if res == 'Hold On': + msg = 'Ready for work...' + time.sleep(0.1) + elif res == 'Finished': + self.logger.info("Worker done. Shutting down") + break + else: + # got a batch id: + with open(os.path.join(self.tmpdir,res),'rb') as p: + data = pickle.load(p) + # remove pickled batch + os.unlink(os.path.join(self.tmpdir,res)) + # process : stats are send back as next 'ready for work' result. + try: + msg = self._process_batch(data['tensor_size'],data['batch_ix'], data['data'],data['length']) + except Exception as e: + self.logger.error(f"Error processing batch {data['tensor_size']}|{data['batch_ix']}: {repr(e)}") + # send error message back to server + msg = "Error : {}".format(repr(e)) + + # send signal to server thread to exit. + s.send(str.encode('Done')) + self.logger.info(f"Closing Worker on device {self.device}") + + + def _process_batch(self,tensor_size,batch_ix, prediction_batch,nr_preds): start = time.time() - total_batch_predictions = 0 - logger.debug('Starting process_batch') - + # Sanity check dump of batch sizes - batch_sizes = ["{}:{}".format(tensor_size, len(batch)) for tensor_size, batch in self.batches.items()] - logger.debug('Batch Sizes: {}'.format(batch_sizes)) - - # Collect each batch's predictions - batch_preds = {} - for tensor_size, batch in self.batches.items(): - # Convert list of encodings into a proper sized numpy matrix - prediction_batch = np.concatenate(batch, axis=0) - - # Run predictions - batch_preds[tensor_size] = np.mean( - get_preds(self.ann, prediction_batch, self.prediction_batch_size), axis=0 - ) - - # Iterate over original list of vcf records, reconstructing record with annotations - for prepared_record in self.prepared_vcf_records: - record_predictions = self._write_record(prepared_record, batch_preds) - total_batch_predictions += record_predictions - - self._clear_batch() - logger.debug('Predictions: {}, VCF Records: {}'.format(self.total_predictions, self.total_vcf_records)) + self.logger.debug('Tensor size : {} : batch_ix {} : nr.entries : {}'.format(tensor_size, batch_ix , nr_preds)) + + # Run predictions && add to shelf. + self.shelf_preds["{}|{}".format(tensor_size,batch_ix)] = np.mean( + get_preds(self.ann, prediction_batch, self.tensorflow_batch_size), axis=0 + ) + + # status duration = time.time() - start - preds_per_sec = total_batch_predictions / duration + preds_per_sec = nr_preds / duration preds_per_hour = preds_per_sec * 60 * 60 - logger.debug('Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(duration, - preds_per_sec, - preds_per_hour)) - - def _write_record(self, prepared_record, batch_preds): - record = prepared_record.vcf_record - gene_info = prepared_record.gene_info - record_predictions = 0 - - all_y_ref = [] - all_y_alt = [] - - # Each prediction in the batch is located and put into the correct y - for location in prepared_record.locations: - # No prediction here - if location.tensor_size == 0: - if location.sequence_type == SequenceType_REF: - all_y_ref.append(None) - else: - all_y_alt.append(None) - continue - - # Extract the prediction from the batch into a list of predictions for this record - batch = batch_preds[location.tensor_size] - if location.sequence_type == SequenceType_REF: - all_y_ref.append(batch[[location.batch_index], :, :]) - else: - all_y_alt.append(batch[[location.batch_index], :, :]) - - delta_scores = extract_delta_scores( - all_y_ref=all_y_ref, - all_y_alt=all_y_alt, - record=record, - ann=self.ann, - dist_var=self.dist, - mask=self.mask, - gene_info=gene_info, - ) - # If there are predictions, write them to the VCF INFO section - if len(delta_scores) > 0: - record.info['SpliceAI'] = delta_scores - record_predictions += len(delta_scores) - - self.output.write(record) - return record_predictions - - def add_record(self, record): - """ - Adds a record to a batch. It'll capture the gene information for the record and - save it for later to avoid looking it up again, then it'll encode ref and alt from - the VCF record and place the encoded values into lists of matching sizes. Once the - encoded values are added, a BatchLookupIndex is created so that after the predictions - are made, it knows where to look up the corresponding prediction for the vcf record. - - Once the batch size hits it's capacity, it'll process all the predictions for the - encoded batches. - """ - - self.total_vcf_records += 1 - # Collect gene information for this record - gene_info = self.ann.get_name_and_strand(record.chrom, record.pos) - - # Keep track of how many predictions we're going to make - prediction_count = len(record.alts) * len(gene_info.genes) - self.batch_predictions += prediction_count - self.total_predictions += prediction_count - - # Collect lists of encoded ref/alt sequences - x_ref, x_alt = encode_batch_records(record, self.ann, self.dist, gene_info) - - # List of BatchLookupIndex's so we know how to lookup predictions for records from - # the batches - batch_lookup_indexes = [] - - # Process the encodings into batches - for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)): - - if len(encoded_seq) == 0: - # Add BatchLookupIndex with zeros so when the batch collects the outputs - # it knows that there is no prediction for this record - batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0)) - continue - - # Iterate over the encoded sequence and drop into the correct batch by size and - # create an index to use to pull out the result after batch is processed - for row in encoded_seq: - # Extract the size of the sequence that was encoded to build a batch from - tensor_size = row.shape[1] - - # Create batch for this size - if tensor_size not in self.batches: - self.batches[tensor_size] = [] - - # Add encoded record to batch - self.batches[tensor_size].append(row) - - # Get the index of the record we just added in the batch - cur_batch_record_ix = len(self.batches[tensor_size]) - 1 - - # Store a reference so we can pull out the prediction for this item from the batches - batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, cur_batch_record_ix)) - - # Save the batch locations for this record on the composite object - prepared_record = PreparedVCFRecord( - vcf_record=record, gene_info=gene_info, locations=batch_lookup_indexes - ) - self.prepared_vcf_records.append(prepared_record) - - # If we're reached our threshold for the max items to process, then process the batch - if self.batch_predictions >= self.prediction_batch_size: - self._process_batch() - - def finish(self): - """ - Method to process all the remaining items that have been added to the batch. - """ - if len(self.prepared_vcf_records) > 0: - self._process_batch() + msg = 'Device {} : Finished in {:0.2f}s, per sec: {:0.2f}, per hour: {:0.2f}'.format(self.device, duration, preds_per_sec, preds_per_hour) + self.logger.debug(msg) + return msg + + + +if __name__ == '__main__': + main() diff --git a/spliceai/batch/batch_utils.py b/spliceai/batch/batch_utils.py index 310a82a..c6893ee 100644 --- a/spliceai/batch/batch_utils.py +++ b/spliceai/batch/batch_utils.py @@ -1,107 +1,226 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. +# Invitae source code modified to improve GPU utilization +# Modifications made by Geert Vandeweyer (Antwerp University Hospital, Belgium) + + import logging +import shelve +import os +import gc +import numpy as np +import tensorflow as tf +import pickle +import socket +from multiprocessing import Process +import subprocess +import time +import sys from spliceai.utils import get_alt_gene_delta_score, is_record_valid, get_seq, \ is_location_predictable, get_cov, get_wid, is_valid_alt_record, encode_seqs, create_unhandled_delta_score +sys.path.append('../../spliceai') +from spliceai.batch.data_handlers import VCFReader, VCFWriter + logger = logging.getLogger(__name__) + +########### +## INPUT ## +########### +## routine to create the batches for prediction. +def prepare_batches(ann, args, tmpdir, prediction_queue,nr_workers): + try: + # create the parser object + vcf_reader = VCFReader(ann=ann, + input_data=args.input_data, + prediction_batch_size=args.prediction_batch_size, + prediction_queue=prediction_queue, + tmpdir=tmpdir,dist=args.distance, + ) + # parse records + vcf_reader.add_records() + # finalize last batches + vcf_reader.finish(nr_workers) + # close the shelf. + vcf_reader.shelf_records.close() + # stats + logger.info("Read {} vcf records, queued {} predictions".format(vcf_reader.total_vcf_records, vcf_reader.total_predictions)) + except Exception as e: + logger.error(f"Error in prepare_batches: {repr(e)}") + raise(e) + + + + + +############## +## ANALYSIS ## +############## +## routine to start the worker Threads +def start_workers(prediction_queue, tmpdir, args,devices,mem_per_logical): + # start server socket + s = socket.socket() + host = socket.gethostname() # locahost + port = args.port + logger.info(f"Starting server: {host}:{port}") + + try: + s.bind((host,port)) + except Exception as e: + logger.error(f"Cannot bind to port {port} : {e}") + sys.exit(1) + s.listen(5) + # start client sockets & server threads. + clientThreads = list() + serverThreads = list() + + for device in devices: + # launch the worker. + logger.info(f"Starting worker on device {device.name}, output is available under {tmpdir}") + cmd = ["python",os.path.join(os.path.dirname(os.path.realpath(__file__)),"batch.py"),"-S",str(args.simulated_gpus),"-M",str(int(mem_per_logical)), "-t",tmpdir,"-d",device.name, '-R', args.reference, '-A', args.annotation, '-T', str(args.tensorflow_batch_size), '-P', str(args.port)] + if args.verbose: + cmd.append('-V') + logger.debug(cmd) + fh_stdout = open(tmpdir+'/'+device.name.replace('/','_').replace(':','.')+'.stdout','w') + fh_stderr = open(tmpdir+'/'+device.name.replace('/','_').replace(':','.')+'.stderr','w') + + p = subprocess.Popen(cmd ,stdout=fh_stdout, stderr=fh_stderr) + clientThreads.append(p) + ## then a new thread in the server for this connection. + client, address = s.accept() + logger.debug("Connected to : " + address[0] + ' : ' + str(address[1])) + p = Process(target=_process_server,args=(client,device.name,prediction_queue,)) + p.start() + serverThreads.append(p) + logger.debug(f"Thread {device.name} activated!") + + return clientThreads, serverThreads, devices + +# routine that runs in the server threads, issuing work to the worker_clients. +def _process_server(clientsocket,device,queue): + # initial response + clientsocket.send(str.encode('Server is online')) + while True: + msg = clientsocket.recv(2048).decode('utf-8') + if msg == 'Done': + logger.debug(f"Stopping thread {device}") + break + elif msg.startswith('Error'): + # send finish signal to worker to shut down cleanly + clientsocket.sendall(str.encode('Finished')) + # then raise the error to the main thread. + raise Exception(msg) + elif not msg == 'Ready for work...': + logger.info(msg) + # send/get new item + try: + item = queue.get(False) + except Exception as e: + #print(str(e)) + item = 'Hold On' + + # set reply + clientsocket.sendall(str.encode(str(item))) + + logger.debug(f"Closing {device} socket.") + clientsocket.close() + + + +## get tensorflow predictions using batch-based submissions (used in worker clients) def get_preds(ann, x, batch_size=32): logger.debug('Running get_preds with matrix size: {}'.format(x.shape)) - return [ - ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5) - ] - - -# Heavily based on utils.get_delta_scores but only handles the validation and encoding -# of the record, but doesn't do any of the prediction or post-processing steps -def encode_batch_records(record, ann, dist_var, gene_info): - cov = get_cov(dist_var) - wid = get_wid(cov) - # If the record is not going to get a prediction, return this empty encoding - empty_encoding = ([], []) - - if not is_record_valid(record): - return empty_encoding - - seq = get_seq(record, ann, wid) - if not seq: - return empty_encoding - - if not is_location_predictable(record, seq, wid, dist_var): - return empty_encoding - - all_x_ref = [] - all_x_alt = [] - for alt_ix in range(len(record.alts)): - for gene_ix in range(len(gene_info.idxs)): - - if not is_valid_alt_record(record, alt_ix): - continue - - x_ref, x_alt = encode_seqs(record=record, - seq=seq, - ann=ann, - gene_info=gene_info, - gene_ix=gene_ix, - alt_ix=alt_ix, - wid=wid) - - all_x_ref.append(x_ref) - all_x_alt.append(x_alt) - - return all_x_ref, all_x_alt - - -# Heavily based on utils.get_delta_scores but only handles the post-processing steps after -# the models have made the predictions -def extract_delta_scores( - all_y_ref, all_y_alt, record, ann, dist_var, mask, gene_info -): - cov = get_cov(dist_var) - delta_scores = [] - pred_ix = 0 - for alt_ix in range(len(record.alts)): - for gene_ix in range(len(gene_info.idxs)): - - # Pull prediction out of batch - y_ref = all_y_ref[pred_ix] - y_alt = all_y_alt[pred_ix] - - # No prediction here - if y_ref is None or y_alt is None: - continue - - if not is_valid_alt_record(record, alt_ix): - continue - - if len(record.ref) > 1 and len(record.alts[alt_ix]) > 1: - pred_ix += 1 - delta_score = create_unhandled_delta_score(record.alts[alt_ix], gene_info.genes[gene_ix]) - delta_scores.append(delta_score) - continue - - if pred_ix >= len(all_y_ref) or pred_ix >= len(all_y_alt): - raise LookupError( - 'Prediction index {} does not exist in prediction matrices: ref({}) alt({})'.format( - pred_ix, len(all_y_ref), len(all_y_alt) - ) - ) - - delta_score = get_alt_gene_delta_score(record=record, - ann=ann, - alt_ix=alt_ix, - gene_ix=gene_ix, - y_ref=y_ref, - y_alt=y_alt, - cov=cov, - gene_info=gene_info, - mask=mask) - delta_scores.append(delta_score) - - pred_ix += 1 - - return delta_scores + try: + predictions = [ann.models[m].predict(x, batch_size=batch_size, verbose=0) for m in range(5)] + except Exception as e: + # try a smaller batch (less efficient, but lower on memory). if it crashes again : it raises. + logger.warning("TF.predict failed ({}).Retrying with smaller batch size".format(e)) + predictions = [ann.models[m].predict(x, batch_size=4, verbose=0) for m in range(5)] + # garbage collection to prevent memory overflow... + gc.collect() + return predictions + + +## initialize a single device, hide others (only for non-simulated gpus) +def initialize_one_device(args): + gpus = tf.config.list_physical_devices('GPU') + if not gpus: + return tf.config.list_logical_devices('CPU') + # get the index of specified device. + idx = None + for i in range(len(gpus)): + if gpus[i].name.replace('physical_','') == args.device: + idx = i + break + if idx is None: + logger.error("Device not found") + logger.debug(idx) + logger.debug(args.device) + logger.debug([x.name.replace('physical_','') for x in gpus]) + raise Exception(f"specified device '{args.device}' not found.") + # set visible + tf.config.set_visible_devices(gpus[idx], 'GPU') + logical_devices = tf.config.list_logical_devices('GPU') + return logical_devices + + +## management routine to initialize gpu/cpu devices and do simulated logical devices if needed +def initialize_devices(args): + ## need to simulate gpus ? + gpus = tf.config.list_physical_devices('GPU') + mem_per_logical = 0 + if gpus: + if args.simulated_gpus > 1: + logger.warning(f"Simulating {args.simulated_gpus} logical GPUs on the first physical GPU device") + try: + gpu_mem_mb = _get_gpu_memory() + except Exception as e: + logger.error(f"Could not get GPU memory (needs nvidia-smi) : {e}") + sys.exit(1) + + # Create n virtual GPUs with [available] / n GB memory each + if hasattr(args,'mem_per_logical'): + mem_per_logical = args.mem_per_logical + else: + mem_per_logical = int((gpu_mem_mb[0]-2048) / args.simulated_gpus) + + logger.info(f"Assigning {mem_per_logical}mb of GPU memory per simulated GPU.") + try: + device_list = [tf.config.LogicalDeviceConfiguration(memory_limit=mem_per_logical)] * args.simulated_gpus + tf.config.set_logical_device_configuration( + gpus[0], + device_list) + logical_gpus = tf.config.list_logical_devices('GPU') + + except RuntimeError as e: + # Virtual devices must be set before GPUs have been initialized + raise(e) + prediction_devices = tf.config.list_logical_devices('GPU') + + if not args.gpus.lower() == 'all': + idxs = [int(x) for x in args.gpus.split(',')] + prediction_devices = [prediction_devices[x] for x in idxs] + else: + # run on cpu + prediction_devices = tf.config.list_logical_devices('CPU')[0] + + logger.info("Using the following devices for prediction:") + for d in prediction_devices: + logger.info(f" - {d.name}") + + return prediction_devices, mem_per_logical + +## helper function to get gpu memory. +def _get_gpu_memory(): + command = "nvidia-smi --query-gpu=memory.free --format=csv" + memory_free_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:] + memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] + return memory_free_values + + + + diff --git a/spliceai/batch/data_handlers.py b/spliceai/batch/data_handlers.py new file mode 100644 index 0000000..a96bce3 --- /dev/null +++ b/spliceai/batch/data_handlers.py @@ -0,0 +1,411 @@ +import logging +import shelve +import pysam +import collections +import os +import numpy as np +import pickle +import tensorflow as tf +import sys + +from spliceai.utils import get_cov, get_wid, get_seq, is_record_valid, is_location_predictable, \ + is_valid_alt_record, encode_seqs, create_unhandled_delta_score, get_alt_gene_delta_score + + +logger = logging.getLogger(__name__) + + +## CUSTOM DATA TYPES +SequenceType_REF = 0 +SequenceType_ALT = 1 + +BatchLookupIndex = collections.namedtuple( + # ref/alt size batch for this size index in current batch for this size + 'BatchLookupIndex', 'sequence_type tensor_size batch_ix batch_index' +) + +PreparedVCFRecord = collections.namedtuple( + 'PreparedVCFRecord', 'vcf_idx gene_info locations' +) + + +# class to parse input and prep batches +class VCFReader: + def __init__(self, ann, input_data, prediction_batch_size, prediction_queue, tmpdir, dist): + self.ann = ann + # This is the maximum number of predictions to parse/encode/predict at a time + self.prediction_batch_size = prediction_batch_size + # the vcf file + self.input_data = input_data + # window to consider + self.dist = dist + # Batch vars + self.batches = {} + + # Counts + self.total_predictions = 0 + self.total_vcf_records = 0 + self.batch_counters = {} + + # the queue + self.prediction_queue = prediction_queue + + # shelves to track data. + self.tmpdir = tmpdir + # track records to have order correct + logging.debug("Opening spliceai_records shelf") + try: + self.shelf_records = shelve.open(os.path.join(self.tmpdir,"spliceai_records.shelf")) + except Exception as e: + logging.error(f"Could not open shelf: {e}") + raise(e) + + + def add_records(self): + + try: + vcf = pysam.VariantFile(self.input_data) + except (IOError, ValueError) as e: + logging.error('{}'.format(e)) + raise(e) + for record in vcf: + try: + self.add_record(record) + except Exception as e: + raise(e) + vcf.close() + + + def add_record(self, record): + """ + Adds a record to a batch. It'll capture the gene information for the record and + save it for later to avoid looking it up again, then it'll encode ref and alt from + the VCF record and place the encoded values into lists of matching sizes. Once the + encoded values are added, a BatchLookupIndex is created so that after the predictions + are made, it knows where to look up the corresponding prediction for the vcf record. + + Once the batch size hits it's capacity, it'll process all the predictions for the + encoded batch. + """ + + self.total_vcf_records += 1 + # Collect gene information for this record + gene_info = self.ann.get_name_and_strand(record.chrom, record.pos) + + # Keep track of how many predictions we're going to make + prediction_count = len(record.alts) * len(gene_info.genes) + self.total_predictions += prediction_count + + # Collect lists of encoded ref/alt sequences + x_ref, x_alt = self._encode_batch_records(record, self.ann, self.dist, gene_info) + + # List of BatchLookupIndex's so we know how to lookup predictions for records from + # the batches + batch_lookup_indexes = [] + + # Process the encodings into batches + for var_type, encoded_seq in zip((SequenceType_REF, SequenceType_ALT), (x_ref, x_alt)): + + if len(encoded_seq) == 0: + # Add BatchLookupIndex with zeros so when the batch collects the outputs + # it knows that there is no prediction for this record + batch_lookup_indexes.append(BatchLookupIndex(var_type, 0, 0, 0)) + continue + + # Iterate over the encoded sequence and drop into the correct batch by size and + # create an index to use to pull out the result after batch is processed + for row in encoded_seq: + # Extract the size of the sequence that was encoded to build a batch from + tensor_size = row.shape[1] + + # Create batch for this size + if tensor_size not in self.batches: + self.batches[tensor_size] = [] + self.batch_counters[tensor_size] = 0 + + # Add encoded record to batch 'n' for tensor_size + self.batches[tensor_size].append(row) + + # Get the index of the record we just added in the batch + cur_batch_record_ix = len(self.batches[tensor_size]) - 1 + + # Store a reference so we can pull out the prediction for this item from the batches + batch_lookup_indexes.append(BatchLookupIndex(var_type, tensor_size, self.batch_counters[tensor_size] , cur_batch_record_ix)) + + # Save the batch locations for this record on the composite object + prepared_record = PreparedVCFRecord( + vcf_idx=self.total_vcf_records, gene_info=gene_info, locations=batch_lookup_indexes + ) + # add to shelf by vcf_idx + self.shelf_records[str(self.total_vcf_records)] = prepared_record + + # If we're reached our threshold for the max items to process, then process the batch + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size]) >= self.prediction_batch_size: + logger.debug("Batch {} full. Adding to queue".format(tensor_size)) + # fully prep the batch outside of gpu routine... + data = np.concatenate(self.batches[tensor_size]) + concat_len = len(data) + # offload conversion of batch from np to tensor to CPU + with tf.device('CPU:0'): + data = tf.convert_to_tensor(data) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} + with open(os.path.join(self.tmpdir,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: + pickle.dump(queue_item,p) + self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) + + # reset + self.batches[tensor_size] = [] + self.batch_counters[tensor_size] += 1 + + + + + def finish(self,nr_workers): + """ + Method to process all the remaining items that have been added to the batches. + """ + logger.debug("Queueing remaining batches") + for tensor_size in self.batch_counters: + if len(self.batches[tensor_size] ) > 0: + # fully prep the batch outside of gpu routine... + data = np.concatenate(self.batches[tensor_size]) + concat_len = len(data) + # offload conversion of batch from np to tensor to CPU + with tf.device('CPU:0'): + data = tf.convert_to_tensor(data) + queue_item = {'tensor_size': tensor_size, 'batch_ix': self.batch_counters[tensor_size], 'data' : data, 'length':concat_len} + with open(os.path.join(self.tmpdir,"{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])),"wb") as p: + pickle.dump(queue_item,p) + self.prediction_queue.put("{}--{}.in.pickle".format(tensor_size,self.batch_counters[tensor_size])) + # clear + self.batches[tensor_size] = [] + # all done : push finish signals (one per process device..). + logging.debug("Queueing finish signals") + for i in range(nr_workers): + self.prediction_queue.put('Finished') + + + # Heavily based on utils.get_delta_scores but only handles the validation and encoding + # of the record, but doesn't do any of the prediction or post-processing steps + def _encode_batch_records(self, record, ann, dist_var, gene_info): + cov = get_cov(dist_var) + wid = get_wid(cov) + # If the record is not going to get a prediction, return this empty encoding + empty_encoding = ([], []) + + if not is_record_valid(record): + return empty_encoding + + seq = get_seq(record, ann, wid) + if not seq: + return empty_encoding + + if not is_location_predictable(record, seq, wid, dist_var): + return empty_encoding + + all_x_ref = [] + all_x_alt = [] + for alt_ix in range(len(record.alts)): + for gene_ix in range(len(gene_info.idxs)): + + if not is_valid_alt_record(record, alt_ix): + continue + + x_ref, x_alt = encode_seqs(record=record, + seq=seq, + ann=ann, + gene_info=gene_info, + gene_ix=gene_ix, + alt_ix=alt_ix, + wid=wid) + + all_x_ref.append(x_ref) + all_x_alt.append(x_alt) + + return all_x_ref, all_x_alt + + + + +# class to parse input and prep batches +class VCFWriter: + def __init__(self, args, tmpdir, devices, ann): + + self.args = args + # the vcf file + self.input_data = args.input_data + self.output_data = args.output_data + # window to consider + self.dist = args.distance + # used devices + self.devices = [x.name for x in devices] + # shelves to track data. + self.tmpdir = tmpdir + # track records to have order correct + self.shelf_records = shelve.open(os.path.join(self.tmpdir,"spliceai_records.shelf")) + # trackers + self.total_records = 0 + self.total_predictions = 0 + # annotations. + self.ann = ann + + def process(self): + # prepare the global pred_shelf + self._aggregate_predictions() + + # open the files & update header: + self.vcf_in = pysam.VariantFile(self.input_data) + header = self.vcf_in.header + header.add_line('##INFO=') + self.vcf_out = pysam.VariantFile(self.output_data,mode='w',header=header) + + # write the output vcf. + self._write_records() + + # close shelves + self.shelf_records.close() + self.shelf_preds.close() + + # close output file. + self.vcf_in.close() + self.vcf_out.close() + + # aggregate shelves over the devices + def _aggregate_predictions(self): + logger.debug("Aggregating device shelves") + self.shelf_preds_name = f"spliceai_preds.shelf" + self.shelf_preds = shelve.open(os.path.join(self.tmpdir, self.shelf_preds_name)) + for device in self.devices: + device_shelf_name = f"spliceai_preds.{device[1:].replace(':','_')}.shelf" + device_shelf_preds = shelve.open(os.path.join(self.tmpdir, device_shelf_name)) + for x in device_shelf_preds: + self.shelf_preds[x] = device_shelf_preds[x] + device_shelf_preds.close() + + + + # wrapper to write out all shelved variants + def _write_records(self): + logger.debug("Writing output file") + # parse vcf + line_idx = 0 + batch = [] + last_batch_key = '' + for record in self.vcf_in: + line_idx += 1 + # get prepared record by line_idx + prepared_record = self.shelf_records[str(line_idx)] + gene_info = prepared_record.gene_info + # (REF + #ALT ) * #genes (* 5 models) + self.total_predictions += (1 + len(record.alts)) * len(gene_info.genes) + + all_y_ref = [] + all_y_alt = [] + + # Each prediction in the batch is located and put into the correct y + for location in prepared_record.locations: + # No prediction here + if location.tensor_size == 0: + if location.sequence_type == SequenceType_REF: + all_y_ref.append(None) + else: + all_y_alt.append(None) + continue + + # Extract the prediction from the batch into a list of predictions for this record + # recycle the batch variable if key is the same. + if not last_batch_key == "{}|{}".format(location.tensor_size,location.batch_ix): + last_batch_key = "{}|{}".format(location.tensor_size,location.batch_ix) + batch = self.shelf_preds[last_batch_key] + + if location.sequence_type == SequenceType_REF: + all_y_ref.append(batch[[location.batch_index], :, :]) + else: + all_y_alt.append(batch[[location.batch_index], :, :]) + # get delta scores + delta_scores = self._extract_delta_scores( + all_y_ref=all_y_ref, + all_y_alt=all_y_alt, + record=record, + gene_info=gene_info, + ) + + # If there are predictions, write them to the VCF INFO section + if len(delta_scores) > 0: + record.info['SpliceAI'] = delta_scores + + self.vcf_out.write(record) + # close shelf again + self.total_vcf_records = line_idx + + + # Heavily based on utils.get_delta_scores but only handles the post-processing steps after + # the models have made the predictions + def _extract_delta_scores(self, all_y_ref, all_y_alt, record, gene_info): + # variables: + dist_var = self.dist + ann = self.ann + mask = self.args.mask + + cov = get_cov(dist_var) + delta_scores = [] + pred_ix = 0 + for alt_ix in range(len(record.alts)): + for gene_ix in range(len(gene_info.idxs)): + + + # Pull prediction out of batch + try: + y_ref = all_y_ref[pred_ix] + y_alt = all_y_alt[pred_ix] + except IndexError: + logger.warn("No data for record below, alt_ix {} : gene_ix {} : pred_ix {}".format(alt_ix, gene_ix,pred_ix)) + logger.warn(record) + continue + except Exception as e: + logger.error("Predction error: {}".format(e)) + logger.error(record) + raise e + + # No prediction here + if y_ref is None or y_alt is None: + continue + + if not is_valid_alt_record(record, alt_ix): + continue + + if len(record.ref) > 1 and len(record.alts[alt_ix]) > 1: + pred_ix += 1 + delta_score = create_unhandled_delta_score(record.alts[alt_ix], gene_info.genes[gene_ix]) + delta_scores.append(delta_score) + continue + + if pred_ix >= len(all_y_ref) or pred_ix >= len(all_y_alt): + raise LookupError( + 'Prediction index {} does not exist in prediction matrices: ref({}) alt({})'.format( + pred_ix, len(all_y_ref), len(all_y_alt) + ) + ) + + delta_score = get_alt_gene_delta_score(record=record, + ann=ann, + alt_ix=alt_ix, + gene_ix=gene_ix, + y_ref=y_ref, + y_alt=y_alt, + cov=cov, + gene_info=gene_info, + mask=mask) + delta_scores.append(delta_score) + + pred_ix += 1 + + return delta_scores + + + + diff --git a/spliceai/utils.py b/spliceai/utils.py index 0f79669..c45cf30 100644 --- a/spliceai/utils.py +++ b/spliceai/utils.py @@ -1,6 +1,9 @@ # Original source code modified to add prediction batching support by Invitae in 2021. # Modifications copyright (c) 2021 Invitae Corporation. +# Invitae source code modified to improve GPU utilization +# Modifications made by Geert Vandeweyer (Antwerp University Hospital, Belgium) + import collections from pkg_resources import resource_filename @@ -8,15 +11,16 @@ import numpy as np from pyfaidx import Fasta from keras.models import load_model +import tensorflow as tf import logging - +import gc GeneInfo = collections.namedtuple('GeneInfo', 'genes strands idxs') class Annotator: - def __init__(self, ref_fasta, annotations): + def __init__(self, ref_fasta, annotations,cpu=True): if annotations == 'grch37': annotations = resource_filename(__name__, 'annotations/grch37.txt') @@ -46,9 +50,13 @@ def __init__(self, ref_fasta, annotations): except IOError as e: logging.error('{}'.format(e)) exit() - paths = ('models/spliceai{}.h5'.format(x) for x in range(1, 6)) - self.models = [load_model(resource_filename(__name__, x)) for x in paths] + # use CPU memory for loading models, to prevent gpu memory allocation. + if cpu: + with tf.device('CPU:0'): + self.models = [load_model(resource_filename(__name__, x)) for x in paths] + else: + self.models = [load_model(resource_filename(__name__, x)) for x in paths] def get_name_and_strand(self, chrom, pos): @@ -215,8 +223,8 @@ def get_delta_scores(record, ann, dist_var, mask): alt_ix=alt_ix, wid=wid) - y_ref = np.mean([ann.models[m].predict(x_ref) for m in range(5)], axis=0) - y_alt = np.mean([ann.models[m].predict(x_alt) for m in range(5)], axis=0) + y_ref = np.mean([ann.models[m].predict(x_ref,verbose=0) for m in range(5)], axis=0) + y_alt = np.mean([ann.models[m].predict(x_alt,verbose=0) for m in range(5)], axis=0) delta_score = get_alt_gene_delta_score(record=record, ann=ann, @@ -228,7 +236,7 @@ def get_delta_scores(record, ann, dist_var, mask): gene_info=gene_info, mask=mask) delta_scores.append(delta_score) - + gc.collect() return delta_scores