diff --git a/.dockstore.yml b/.dockstore.yml index 2ac8bbda057..617875994f4 100644 --- a/.dockstore.yml +++ b/.dockstore.yml @@ -136,3 +136,36 @@ workflows: - master tags: - /.*/ + - name: permutect_call_variants_with_uda + subclass: WDL + primaryDescriptorPath: /scripts/permutect/call_variants_with_uda.wdl +# testParameterFiles: +# - /scripts/pathseq/wdl/pathseq_pipeline_template.json +# filters: +# branches: +# - master +# - je_updateCondaEnvironment +# tags: +# - /.*/ + - name: permutect_make_training_dataset + subclass: WDL + primaryDescriptorPath: /scripts/permutect/make_training_dataset.wdl + # testParameterFiles: + # - /scripts/pathseq/wdl/pathseq_pipeline_template.json +# filters: +# branches: +# - master +# - je_updateCondaEnvironment +# tags: +# - /.*/ + - name: permutect_train_base_model + subclass: WDL + primaryDescriptorPath: /scripts/permutect/permutect_train_base_model.wdl + # testParameterFiles: + # - /scripts/pathseq/wdl/pathseq_pipeline_template.json +# filters: +# branches: +# - master +# - je_updateCondaEnvironment +# tags: +# - /.*/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 65abb0cb939..8070d8dd96f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,6 +39,17 @@ RUN rm /etc/apt/sources.list.d/google-cloud-sdk.list && \ apt-get -y autoremove && \ rm -rf /var/lib/apt/lists/* +# Install CUDA drivers +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb && \ + dpkg -i cuda-keyring_1.0-1_all.deb && \ + apt-get update && \ + apt-get -y install cuda-drivers && \ + apt-get -y clean && \ + apt-get -y autoclean && \ + apt-get -y autoremove && \ + rm -rf /var/lib/apt/lists/* + + WORKDIR /gatk RUN chmod -R a+rw /gatk diff --git a/scripts/gatkcondaenv.yml.template b/scripts/gatkcondaenv.yml.template index 51d00ce99e8..fe82e3b9e67 100644 --- a/scripts/gatkcondaenv.yml.template +++ b/scripts/gatkcondaenv.yml.template @@ -17,26 +17,34 @@ channels: # if channels other than conda-forge are added and the channel order is changed (note that conda channel_priority is currently set to flexible), # verify that key dependencies are installed from the correct channel - conda-forge +- pytorch +- nvidia dependencies: # core python dependencies - conda-forge::python=3.10.13 # do not update without good reason - conda-forge:pip=23.3.1 -- conda-forge:blas=1.0=mkl # our official environment uses MKL versions of various packages; if other versions are desired, users should edit this YML accordingly +- conda-forge:blas=1. # our official environment uses MKL versions of various packages; if other versions are desired, users should edit this YML accordingly - conda-forge::numpy=1.26.2 - conda-forge::pymc=5.10.1 - conda-forge::pytensor=2.18.3 - conda-forge::scipy=1.11.4 - conda-forge::h5py=3.10.0 -- conda-forge::pytorch=2.1.0=*mkl*100 +- pytorch::pytorch=2.1.0 - conda-forge::pytorch-lightning=2.4.0 # supports Pytorch >= 2.1 and <= 2.4, used by NVScoreVariants +- pytorch::pytorch-cuda=12.1 - conda-forge::scikit-learn=1.3.2 - conda-forge::matplotlib=3.8.2 - conda-forge::pandas=2.1.3 - conda-forge::tqdm=4.66.1 - conda-forge::dill=0.3.7 # used for pickling lambdas in TrainVariantAnnotationsModel - conda-forge::biopython=1.84 # used by NVScoreVariants +- conda-forge::tensorboard=2.8.0 +- conda-forge::setuptools>=57.0.0 +- conda-forge::psutil>=5.9.2 +# - conda-forge::protobuf<3.20,>=3.9.2 # Protobuf constraint for TensorFlow compatibility +- conda-forge::intervaltree~=3.1.0 # core R dependencies; these should only be used for plotting and do not take precedence over core python dependencies! - r-base=4.3.1 @@ -52,7 +60,9 @@ dependencies: # other python dependencies; these should be removed after functionality is moved into Java code - bioconda::pysam=0.22.0 - conda-forge::pyvcf=0.6.8 +- bioconda::cyvcf2~=0.30.15 # pip installs should be avoided, as pip may not respect the dependencies found by the conda solver - pip: - - gatkPythonPackageArchive.zip + - mmap-ninja>=0.2.4 + - gatkPythonPackageArchive.zip \ No newline at end of file diff --git a/scripts/permutect/call_variants_with_uda.wdl b/scripts/permutect/call_variants_with_uda.wdl new file mode 100644 index 00000000000..9a4657ab8ad --- /dev/null +++ b/scripts/permutect/call_variants_with_uda.wdl @@ -0,0 +1,276 @@ +version 1.0 + +# run Mutect2 to get both training AND test datasets. The training dataset is preprocessed and combined with +# high-quality labeled data to make a UDA dataset, then used to train an artifact model. The test dataset is used +# for the posterior model and filtering. + +# note that the artifact model can be trained before the Mutect2 workflow runs FilterMutectCalls + +import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/18/plain-WDL/descriptor" as m2 +import "permutect-uda-dataset.wdl" as uda +import "permutect-train-artifact-model.wdl" as training +import "permutect-call-variants.wdl" as calling + +workflow CallVariantsWithUDA { + input { + # basic inputs for Mutect2 + File? intervals + File? masked_intervals + File ref_fasta + File ref_fai + File ref_dict + File primary_bam + File primary_bai + File? control_bam + File? control_bai + File? gnomad + File? gnomad_idx + String? m2_extra_args + File? dragstr_model + Boolean make_bamout = false + Boolean compress_vcfs = false + + # Mutect2 filtering + Boolean skip_m2_filtering + File? variants_for_contamination + File? variants_for_contamination_idx + File? realignment_index_bundle + String? realignment_extra_args + Boolean? run_orientation_bias_mixture_model_filter + + # preprocessing arguments + Int chunk_size + + # training arguments for both artifact model and posterior model + Int batch_size + Int inference_batch_size + Int num_workers + Int? gpu_count + Int? training_mem + + # UDA training arguments + File base_model + File source_train_tar + String source_edit_type = "keep_everything" + String target_edit_type = "unlabel_everything" + Int num_epochs + Int num_calibration_epochs + Float dropout_p + Array[Int] aggregation_layers + Array[Int] calibration_layers + String? training_extra_args + Boolean learn_artifact_spectra + Float? genomic_span + + # Permutect filtering / posterior model + File? test_dataset_truth_vcf # used for evaluation + File? test_dataset_truth_vcf_idx + Int? num_spectrum_iterations + Float? spectrum_learning_rate + String? permutect_filtering_extra_args + String bcftools_docker = "us.gcr.io/broad-dsde-methods/davidben/bcftools" + File? obscene_hack_leave_unset + + + # runtime + String gatk_docker + String permutect_docker + File? gatk_override + String basic_bash_docker = "ubuntu:16.04" + Int scatter_count + Int preemptible = 2 + Int max_retries = 1 + Int small_task_cpu = 2 + Int small_task_mem = 4 + Int small_task_disk = 100 + Int boot_disk_size = 12 + Int learn_read_orientation_mem = 8000 + Int filter_alignment_artifacts_mem = 9000 + String? gcs_project_for_requester_pays + + # Use as a last resort to increase the disk given to every task in case of ill behaving data + Int emergency_extra_disk = 0 + } + + # note: we make both training and test datasets + # note: for speed we may skip filtering in order to begin UDA artifact model training immediately + # the only M2 filtering we may need is contamination, and that may be skipped + call m2.Mutect2 { + input: + intervals = intervals, + masked_intervals = masked_intervals, + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + tumor_reads = primary_bam, + tumor_reads_index = primary_bai, + normal_reads = control_bam, + normal_reads_index = control_bai, + gnomad = gnomad, + gnomad_idx = gnomad_idx, + variants_for_contamination = variants_for_contamination, + variants_for_contamination_idx = variants_for_contamination_idx, + realignment_index_bundle = realignment_index_bundle, + realignment_extra_args = realignment_extra_args, + run_orientation_bias_mixture_model_filter = run_orientation_bias_mixture_model_filter, + m2_extra_args = m2_extra_args, + dragstr_model = dragstr_model, + make_bamout = make_bamout, + make_permutect_training_dataset = true, + make_permutect_test_dataset = true, + permutect_test_dataset_truth_vcf = test_dataset_truth_vcf, + permutect_test_dataset_truth_vcf_idx = test_dataset_truth_vcf_idx, + skip_filtering = skip_m2_filtering, + gatk_docker = gatk_docker, + gatk_override = gatk_override, + scatter_count = scatter_count, + preemptible = preemptible, + max_retries = max_retries, + small_task_cpu = small_task_cpu, + small_task_mem = small_task_mem, + small_task_disk = small_task_disk, + boot_disk_size = boot_disk_size, + gcs_project_for_requester_pays = gcs_project_for_requester_pays, + emergency_extra_disk = emergency_extra_disk + } + + # preprocess the training data from Mutect2 + call Preprocess { + input: + training_dataset = select_first([Mutect2.permutect_training_dataset]), + chunk_size = chunk_size, + permutect_docker = permutect_docker + } + + # combine the source_tar and preprocessed training data into a UDA dataset + call uda.PermutectUDADataset { + input: + source_train_tar = source_train_tar, + target_train_tar = Preprocess.train_tar, + source_edit_type = source_edit_type, + target_edit_type = target_edit_type, + chunk_size = chunk_size, + permutect_docker = permutect_docker, + preemptible = 0, + max_retries = 0 + } + + # train an artifact model on the UDA dataset + call training.TrainPermutect { + input: + train_tar = PermutectUDADataset.uda_train_tar, + base_model = base_model, + num_epochs = num_epochs, + num_calibration_epochs = num_calibration_epochs, + batch_size = batch_size, + inference_batch_size = inference_batch_size, + num_workers = num_workers, + mem = training_mem, + gpu_count = gpu_count, + dropout_p = dropout_p, + aggregation_layers = aggregation_layers, + calibration_layers = calibration_layers, + extra_args = training_extra_args, + learn_artifact_spectra = learn_artifact_spectra, + genomic_span = genomic_span, + permutect_docker = permutect_docker, + preemptible = 0, + max_retries = 0 + } + + # we already ran M2 so we don't need the entire calling workflow, just the post-M2 parts of it + call calling.SplitMultiallelics { + input: + input_vcf = Mutect2.output_vcf, + input_vcf_idx = Mutect2.output_vcf_idx, + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + bcftools_docker = bcftools_docker + } + + call calling.IndexVCF as IndexAfterSplitting { + input: + unindexed_vcf = SplitMultiallelics.output_vcf, + gatk_docker = gatk_docker + } + + call calling.PermutectFiltering { + input: + mutect2_vcf = IndexAfterSplitting.vcf, + mutect2_vcf_idx = IndexAfterSplitting.vcf_index, + permutect_model = TrainPermutect.artifact_model, + test_dataset = select_first([Mutect2.permutect_test_dataset]), + contigs_table = Mutect2.permutect_contigs_table, + maf_segments = Mutect2.maf_segments, + mutect_stats = Mutect2.mutect_stats, + batch_size = batch_size, + num_workers = num_workers, + gpu_count = gpu_count, + num_spectrum_iterations = num_spectrum_iterations, + spectrum_learning_rate = spectrum_learning_rate, + chunk_size = chunk_size, + permutect_filtering_extra_args = permutect_filtering_extra_args, + permutect_docker = permutect_docker, + } + + + call calling.IndexVCF as IndexAfterFiltering { + input: + unindexed_vcf = PermutectFiltering.output_vcf, + gatk_docker = gatk_docker + } + + output { + File? bamout = Mutect2.bamout + File? bamout_index = Mutect2.bamout_index + File mutect_stats = Mutect2.mutect_stats + File permutect_contigs_table = Mutect2.permutect_contigs_table + File permutect_read_groups_table = Mutect2.permutect_read_groups_table + File train_tar = Preprocess.train_tar + File training_tensorboard_tar = TrainPermutect.training_tensorboard_tar + File output_vcf = IndexAfterFiltering.vcf + File output_vcf_idx = IndexAfterFiltering.vcf_index + File calling_tensorboard_tar = PermutectFiltering.tensorboard_report + } + +} + +task Preprocess { + input { + File training_dataset + Int chunk_size + Int? source_label + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + gatk PermutectPreprocessDataset --training-datasets ~{training_dataset} --chunk-size ~{chunk_size} ~{"--sources " + source_label} --output train.tar + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 2]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File train_tar = "train.tar" + } +} \ No newline at end of file diff --git a/scripts/permutect/dragstr_calibration.wdl b/scripts/permutect/dragstr_calibration.wdl new file mode 100644 index 00000000000..b74776ab2ce --- /dev/null +++ b/scripts/permutect/dragstr_calibration.wdl @@ -0,0 +1,85 @@ +version 1.0 + +## NOTE: this is essentially copied from https://github.com/broadinstitute/warp/blob/develop/tasks/broad/DragenTasks.wdl +## with minor modifications + +workflow DragstrCalibration { + input { + File ref_fasta + File ref_fai + File ref_dict + File reads + File reads_index + } + + call CalibrateDragstrModel { + input: + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + reads = reads, + reads_index = reads_index + } + + output { + File dragstr_model = CalibrateDragstrModel.dragstr_model + } +} + +task CalibrateDragstrModel { + input { + File ref_fasta + File ref_fai + File ref_dict + File reads + File reads_index + + String docker = "us.gcr.io/broad-gatk/gatk:4.5.0.0" + Int preemptible_tries = 1 + Int threads = 4 + Int? disk_space + Int? mem + } + + String parallel_args = "--threads " + threads + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 4000 + Int command_mem = machine_mem - 500 + + parameter_meta{ + ref_fasta: {localization_optional: true} + ref_fai: {localization_optional: true} + ref_dict: {localization_optional: true} + reads: {localization_optional: true} + reads_index: {localization_optional: true} + } + + command <<< + set -x + + gatk ComposeSTRTableFile \ + -R ~{ref_fasta} \ + -O str_table.tsv + + + gatk --java-options "-Xmx~{command_mem}m" CalibrateDragstrModel \ + -R ~{ref_fasta} \ + -I ~{reads} \ + -str str_table.tsv \ + -O params.dragstr \ + ~{parallel_args} + >>> + + runtime { + docker: docker + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + memory: machine_mem + " MB" + preemptible: preemptible_tries + cpu: threads + } + + output { + File dragstr_model = "params.dragstr" + } +} \ No newline at end of file diff --git a/scripts/permutect/make_training_dataset.wdl b/scripts/permutect/make_training_dataset.wdl new file mode 100644 index 00000000000..62a8bb1225a --- /dev/null +++ b/scripts/permutect/make_training_dataset.wdl @@ -0,0 +1,142 @@ +version 1.0 + +# run Mutect2 without filtering to get plain text training data, then run preprocess_dataset. + +import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/18/plain-WDL/descriptor" as m2 + +workflow MakeTrainingDataset { + input { + # basic inputs + File? intervals + File? masked_intervals + File ref_fasta + File ref_fai + File ref_dict + File reads + File reads_index + + File? gnomad + File? gnomad_idx + + # extra arguments + String? m2_extra_args + + # preprocessing arguments + Int chunk_size + + # additional modes and outputs + File? dragstr_model + Boolean make_bamout = false + Boolean compress_vcfs = false + File? permutect_training_dataset_truth_vcf + File? permutect_training_dataset_truth_vcf_idx + + + # runtime + String gatk_docker + String permutect_docker + File? gatk_override + String basic_bash_docker = "ubuntu:16.04" + Int scatter_count + Int preemptible = 2 + Int max_retries = 1 + Int small_task_cpu = 2 + Int small_task_mem = 4 + Int small_task_disk = 100 + Int boot_disk_size = 12 + Int learn_read_orientation_mem = 8000 + Int filter_alignment_artifacts_mem = 9000 + String? gcs_project_for_requester_pays + + # Use as a last resort to increase the disk given to every task in case of ill behaving data + Int emergency_extra_disk = 0 + } + + call m2.Mutect2 { + input: + intervals = intervals, + masked_intervals = masked_intervals, + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + tumor_reads = reads, + tumor_reads_index = reads_index, + gnomad = gnomad, + gnomad_idx = gnomad_idx, + m2_extra_args = m2_extra_args, + dragstr_model = dragstr_model, + make_bamout = make_bamout, + make_permutect_training_dataset = true, + permutect_training_dataset_truth_vcf = permutect_training_dataset_truth_vcf, + permutect_training_dataset_truth_vcf_idx = permutect_training_dataset_truth_vcf_idx, + skip_filtering = true, + gatk_docker = gatk_docker, + gatk_override = gatk_override, + scatter_count = scatter_count, + preemptible = preemptible, + max_retries = max_retries, + small_task_cpu = small_task_cpu, + small_task_mem = small_task_mem, + small_task_disk = small_task_disk, + boot_disk_size = boot_disk_size, + gcs_project_for_requester_pays = gcs_project_for_requester_pays, + emergency_extra_disk = emergency_extra_disk + } + + call Preprocess { + input: + training_dataset = select_first([Mutect2.permutect_training_dataset]), + chunk_size = chunk_size, + permutect_docker = permutect_docker + } + + output { + File? bamout = Mutect2.bamout + File? bamout_index = Mutect2.bamout_index + File mutect_stats = Mutect2.mutect_stats + File permutect_contigs_table = Mutect2.permutect_contigs_table + File permutect_read_groups_table = Mutect2.permutect_read_groups_table + File train_tar = Preprocess.train_tar + } + +} + +task Preprocess { + input { + File training_dataset + Int chunk_size + Int? source_label + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + Boolean use_ssd = true + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + gatk PermutectPreprocessDataset --training-datasets ~{training_dataset} --chunk-size ~{chunk_size} ~{"--sources " + source_label} --output train.tar + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 2]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File train_tar = "train.tar" + } +} \ No newline at end of file diff --git a/scripts/permutect/permutect.wdl b/scripts/permutect/permutect.wdl new file mode 100644 index 00000000000..94a16271ad0 --- /dev/null +++ b/scripts/permutect/permutect.wdl @@ -0,0 +1,283 @@ +version 1.0 + +import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/17/plain-WDL/descriptor" as m2 + +workflow Permutect { + input { + File permutect_model + + File? intervals + File? masks + File ref_fasta + File ref_fai + File ref_dict + Int scatter_count + Int? num_spectrum_iterations + Float? spectrum_learning_rate + File primary_bam + File primary_bai + File? control_bam + File? control_bai + File? gnomad + File? gnomad_idx + File? variants_for_contamination + File? variants_for_contamination_idx + File? realignment_index_bundle + File? dragstr_model + String? realignment_extra_args + Boolean? run_orientation_bias_mixture_model_filter + String? m2_extra_args + String? split_intervals_extra_args + Int batch_size + Int num_workers + Int? gpu_count + Int chunk_size + File? test_dataset_truth_vcf # used for evaluation + File? test_dataset_truth_vcf_idx + + String? permutect_filtering_extra_args + String gatk_docker + String bcftools_docker + File? gatk_override + String permutect_docker + Int? preemptible + Int? max_retries + File? obscene_hack_leave_unset + } + + call m2.Mutect2 { + input: + make_permutect_training_dataset = false, + make_permutect_test_dataset = true, + permutect_training_dataset_truth_vcf = test_dataset_truth_vcf, + permutect_training_dataset_truth_vcf_idx = test_dataset_truth_vcf_idx, + intervals = intervals, + masked_intervals = masks, + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + tumor_reads = primary_bam, + tumor_reads_index = primary_bai, + normal_reads = if control_bam == "" then obscene_hack_leave_unset else control_bam, + normal_reads_index = if control_bam == "" then obscene_hack_leave_unset else control_bai, + + scatter_count = scatter_count, + gnomad = gnomad, + gnomad_idx = gnomad_idx, + variants_for_contamination = variants_for_contamination, + variants_for_contamination_idx = variants_for_contamination_idx, + realignment_index_bundle = realignment_index_bundle, + realignment_extra_args = realignment_extra_args, + dragstr_model = dragstr_model, + run_orientation_bias_mixture_model_filter = run_orientation_bias_mixture_model_filter, + m2_extra_args = m2_extra_args, + make_bamout = false, + + gatk_docker = gatk_docker, + gatk_override = gatk_override, + preemptible = preemptible, + max_retries = max_retries + } + + call SplitMultiallelics { + input: + input_vcf = Mutect2.output_vcf, + input_vcf_idx = Mutect2.output_vcf_idx, + ref_fasta = ref_fasta, + ref_fai = ref_fai, + ref_dict = ref_dict, + bcftools_docker = bcftools_docker + } + + call IndexVCF as IndexAfterSplitting { + input: + unindexed_vcf = SplitMultiallelics.output_vcf, + gatk_docker = gatk_docker + } + + call PermutectFiltering { + input: + mutect2_vcf = IndexAfterSplitting.vcf, + mutect2_vcf_idx = IndexAfterSplitting.vcf_index, + permutect_model = permutect_model, + test_dataset = select_first([Mutect2.permutect_test_dataset]), + contigs_table = Mutect2.permutect_contigs_table, + maf_segments = Mutect2.maf_segments, + mutect_stats = Mutect2.mutect_stats, + batch_size = batch_size, + num_workers = num_workers, + gpu_count = gpu_count, + num_spectrum_iterations = num_spectrum_iterations, + spectrum_learning_rate = spectrum_learning_rate, + chunk_size = chunk_size, + permutect_filtering_extra_args = permutect_filtering_extra_args, + permutect_docker = permutect_docker, + } + + call IndexVCF as IndexAfterFiltering { + input: + unindexed_vcf = PermutectFiltering.output_vcf, + gatk_docker = gatk_docker + } + + output { + File output_vcf = IndexAfterFiltering.vcf + File output_vcf_idx = IndexAfterFiltering.vcf_index + File tensorboard_report = PermutectFiltering.tensorboard_report + File test_dataset = select_first([Mutect2.permutect_test_dataset]) + File mutect2_vcf = Mutect2.output_vcf + File mutect2_vcf_idx = Mutect2.output_vcf_idx + } +} + + task PermutectFiltering { + input { + File permutect_model + File test_dataset + File contigs_table + File mutect2_vcf + File mutect2_vcf_idx + File? maf_segments + File? normal_maf_segments + File mutect_stats + Int? num_spectrum_iterations + Float? spectrum_learning_rate + Int batch_size + Int num_workers + Int? gpu_count + Int chunk_size + String? permutect_filtering_extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + # set -e + genomic_span=`grep "callable" ~{mutect_stats} | while read name value; do echo $value; done` + + gatk PermutectFilterVariants --input ~{mutect2_vcf} --test-dataset ~{test_dataset} \ + --permutect-model ~{permutect_model} \ + --contigs-table ~{contigs_table} \ + --output permutect-filtered.vcf \ + --tensorboard-dir tensorboard \ + --batch-size ~{batch_size} --num-workers ~{num_workers} --chunk-size ~{chunk_size} \ + ~{" --num-spectrum-iterations " + num_spectrum_iterations} \ + ~{" --spectrum-learning-rate " + spectrum_learning_rate} \ + ~{" --maf-segments " + maf_segments} ~{" --normal-maf-segments " + normal_maf_segments} \ + --genomic-span $genomic_span ~{permutect_filtering_extra_args} + + tar cvf tensorboard.tar tensorboard/ + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 0]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 2]) + gpuType: "nvidia-tesla-t4" + gpuCount: select_first([gpu_count, 1]) + nvidiaDriverVersion: "535.183.01" + zones : ["us-central1-a", "us-central1-b", "us-central1-c"] + } + + output { + File output_vcf = "permutect-filtered.vcf" + File tensorboard_report = "tensorboard.tar" + } +} + +task SplitMultiallelics { + input { + File input_vcf + File input_vcf_idx + File ref_fasta + File ref_fai + File ref_dict + String bcftools_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 4000 + Int command_mem = machine_mem - 500 + + command <<< + + bcftools norm -m -any -f ~{ref_fasta} ~{input_vcf} > output.vcf + + >>> + + runtime { + docker: bcftools_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 0]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 2]) + } + + output { + File output_vcf = "output.vcf" + } +} + +task IndexVCF { + input { + File unindexed_vcf + String gatk_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 4000 + Int command_mem = machine_mem - 500 + + command <<< + + cp ~{unindexed_vcf} indexed.vcf + + gatk --java-options "-Xmx~{command_mem}m" IndexFeatureFile -I indexed.vcf + + gatk --java-options "-Xmx~{command_mem}m" SelectVariants -V indexed.vcf -O output.vcf --lenient \ + -DGA DP -DGA AF -DGA F1R2 -DGA F2R1 -DGA FAD -DGA SB \ + -DA AS_FilterStatus -DA AS_SB_TABLE -DA ECNT -DA GERMQ -DA MBQ -DA MFRL -DA MMQ -DA MPOS + + set -e + >>> + + runtime { + docker: gatk_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 0]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File vcf = "output.vcf" + File vcf_index = "output.vcf.idx" + } +} diff --git a/scripts/permutect/permutect_edit_dataset.wdl b/scripts/permutect/permutect_edit_dataset.wdl new file mode 100644 index 00000000000..06511118144 --- /dev/null +++ b/scripts/permutect/permutect_edit_dataset.wdl @@ -0,0 +1,76 @@ +version 1.0 + + +workflow EditDataset { + input { + File train_tar + Int chunk_size + Int? new_source + String edit_type + String? extra_args + + String permutect_docker + } + + call EditDataset { + input: + train_tar = train_tar, + permutect_docker = permutect_docker, + chunk_size = chunk_size, + new_source = new_source, + edit_type = edit_type, + extra_args = extra_args + } + + output { + File edited_dataset_tarfile = EditDataset.output_dataset_tarfile + } +} + +task EditDataset { + input { + File train_tar + Int chunk_size + Int? new_source + String edit_type + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + Boolean use_ssd = false + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + gatk PermutectEditDataset \ + --train-tar ~{sep=" --train-tar " + tdrain_tar} \ + --chunk-size ~{chunk_size} \ + {" --source " + new_source} \ + --dataset-edit ~{edit_type} \ + --output edited-dataset.tar \ + ~{extra_args} + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 10]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File output_dataset_tarfile = "edited_dataset.tar" + } +} diff --git a/scripts/permutect/permutect_evaluation.wdl b/scripts/permutect/permutect_evaluation.wdl new file mode 100644 index 00000000000..c7a9f12403f --- /dev/null +++ b/scripts/permutect/permutect_evaluation.wdl @@ -0,0 +1,79 @@ +version 1.0 + +workflow PermutectEvaluation { + input { + File permutect_model + File evaluation_tar + Int batch_size + Int? num_workers + + String permutect_docker + Int? preemptible + Int? max_retries + } + + + call Evaluate { + input: + evaluation_tar = evaluation_tar, + permutect_model = permutect_model, + batch_size = batch_size, + num_workers = num_workers, + permutect_docker = permutect_docker + } + + + output { + File tensorboard_tar = Evaluate.tensorboard_tar + } +} + +task Evaluate { + input { + File evaluation_tar + File permutect_model + Int batch_size + Int? num_workers + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + Boolean use_ssd = true + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + evaluate_model \ + --evaluation_tar ~{evaluation_tar} \ + --permutect_model ~{permutect_model} \ + --batch_size ~{batch_size} \ + ~{"--num_workers " + num_workers} \ + --tensorboard_dir tensorboard \ + ~{extra_args} + + tar cvf tensorboard.tar tensorboard/ + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 10]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File tensorboard_tar = "tensorboard.tar" + } +} diff --git a/scripts/permutect/permutect_make_uda_dataset.wdl b/scripts/permutect/permutect_make_uda_dataset.wdl new file mode 100644 index 00000000000..4c64107e4d4 --- /dev/null +++ b/scripts/permutect/permutect_make_uda_dataset.wdl @@ -0,0 +1,99 @@ +version 1.0 + +# source and target train tar have already been preprocessed +# source needs to be given the '0' source label +# target needs to be given the '1' source label +workflow PermutectUDADataset { + input { + File source_train_tar + File target_train_tar + + # most likely KEEP_EVERYTHING for the source and UNLABEL_ARTIFACTS or UNLABEL_EVERYTHING for target + String source_edit_type + String target_edit_type + + Int chunk_size + + String permutect_docker + Int? preemptible + Int? max_retries + } + + call EditDataset as EditSource { + input: + train_tar = [source_train_tar], + new_source = 0, + edit_type = source_edit_type, + chunk_size = chunk_size, + permutect_docker = permutect_docker + } + + call EditDataset as EditTarget { + input: + train_tar = [target_train_tar], + new_source = 1, + edit_type = target_edit_type, + chunk_size = chunk_size, + permutect_docker = permutect_docker + } + + call EditDataset as Merge { + input: + train_tar = [EditSource.output_tarfile, EditTarget.output_tarfile], + edit_type = "keep_everything", + chunk_size = chunk_size, + permutect_docker = permutect_docker + } + + output { + File uda_train_tar = Merge.output_tarfile + } +} + +task EditDataset { + input { + Array[File] train_tar + Int chunk_size + Int? new_source + String edit_type + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + Boolean use_ssd = false + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + gatk PermutectEditDataset \ + --train-tar ~{sep=' --train-tar ' train_tar} \ + --chunk-size ~{chunk_size} \ + ~{" --source " + new_source} \ + --dataset-edit ~{edit_type} \ + --output edited.tar \ + ~{extra_args} + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 10]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File output_tarfile = "edited.tar" + } +} \ No newline at end of file diff --git a/scripts/permutect/permutect_merge_datasets.wdl b/scripts/permutect/permutect_merge_datasets.wdl new file mode 100644 index 00000000000..550cef42c96 --- /dev/null +++ b/scripts/permutect/permutect_merge_datasets.wdl @@ -0,0 +1,76 @@ + + + + + +version 1.0 + + +workflow MergeDatasets { + input { + Array[File] train_tar + Int chunk_size + String? extra_args + + String permutect_docker + } + + call EditDataset { + input: + train_tar = train_tar, + permutect_docker = permutect_docker, + chunk_size = chunk_size, + edit_type = "keep_everything", + extra_args = extra_args + } + + output { + File merged_dataset_tarfile = EditDataset.output_dataset_tarfile + } +} + +task EditDataset { + input { + Array[File] train_tar + Int chunk_size + String edit_type + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + Boolean use_ssd = false + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + edit_dataset \ + --train_tar ~{sep=' ' train_tar} \ + --chunk_size ~{chunk_size} \ + --dataset_edit ~{edit_type} \ + --output edited_dataset.tar \ + ~{extra_args} + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 10]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File output_dataset_tarfile = "edited_dataset.tar" + } +} diff --git a/scripts/permutect/permutect_preprocessing.wdl b/scripts/permutect/permutect_preprocessing.wdl new file mode 100644 index 00000000000..5f582be3e8b --- /dev/null +++ b/scripts/permutect/permutect_preprocessing.wdl @@ -0,0 +1,69 @@ +version 1.0 + + +workflow PreprocessPermutect { + input { + File training_dataset + Int chunk_size + Int? source_label + + String permutect_docker + Int? preemptible + Int? max_retries + } + + call Preprocess { + input: + training_dataset = training_dataset, + permutect_docker = permutect_docker, + preemptible = preemptible, + max_retries = max_retries, + chunk_size = chunk_size, + source_label = source_label + } + + output { + File train_tar = Preprocess.train_tar + } +} + + +task Preprocess { + input { + File training_dataset + Int chunk_size + Int? source_label + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + Boolean use_ssd = false + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + preprocess_dataset --training_datasets ~{training_dataset} --chunk_size ~{chunk_size} ~{"--sources " + source_label} --output train.tar + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD" + preemptible: select_first([preemptible, 2]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + } + + output { + File train_tar = "train.tar" + } +} \ No newline at end of file diff --git a/scripts/permutect/permutect_pruning.wdl b/scripts/permutect/permutect_pruning.wdl new file mode 100644 index 00000000000..1c701849cfc --- /dev/null +++ b/scripts/permutect/permutect_pruning.wdl @@ -0,0 +1,117 @@ +version 1.0 + + +workflow PrunePermutect { + input { + File train_tar + File base_model + Int num_epochs + Int num_calibration_epochs + Int batch_size + Int inference_batch_size + Int chunk_size + Int? num_workers + Float dropout_p + Array[Int] aggregation_layers + Array[Int] calibration_layers + String? train_m3_extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + } + + call PrunePermutect { + input: + train_tar = train_tar, + base_model = base_model, + permutect_docker = permutect_docker, + preemptible = preemptible, + max_retries = max_retries, + num_epochs = num_epochs, + num_calibration_epochs = num_calibration_epochs, + batch_size = batch_size, + inference_batch_size = inference_batch_size, + chunk_size = chunk_size, + num_workers = num_workers, + dropout_p = dropout_p, + aggregation_layers = aggregation_layers, + calibration_layers = calibration_layers, + extra_args = train_m3_extra_args + } + + output { + File pruned_dataset_tarfile = PrunePermutect.pruned_dataset_tarfile + File training_tensorboard_tar = PrunePermutect.tensorboard_tar + } +} + + +task PrunePermutect { + input { + File train_tar + File base_model + + Int num_epochs + Int num_calibration_epochs + Int batch_size + Int inference_batch_size + Int chunk_size + Int? num_workers + Float dropout_p + Array[Int] aggregation_layers + Array[Int] calibration_layers + + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + prune_dataset \ + --train_tar ~{train_tar} \ + --base_model ~{base_model} \ + --aggregation_layers ~{sep=' ' aggregation_layers} \ + --calibration_layers ~{sep=' ' calibration_layers} \ + --dropout_p ~{dropout_p} \ + --batch_size ~{batch_size} \ + --inference_batch_size ~{inference_batch_size} \ + --chunk_size ~{chunk_size} \ + ~{"--num_workers " + num_workers} \ + --num_epochs ~{num_epochs} \ + --num_calibration_epochs ~{num_calibration_epochs} \ + --output pruned_dataset.tar \ + --tensorboard_dir tensorboard \ + ~{extra_args} + + tar cvf tensorboard.tar tensorboard/ + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 10]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + gpuType: "nvidia-tesla-t4" + gpuCount: 1 + } + + output { + File pruned_dataset_tarfile = "pruned_dataset.tar" + File tensorboard_tar = "tensorboard.tar" + } +} \ No newline at end of file diff --git a/scripts/permutect/permutect_train_base_model.wdl b/scripts/permutect/permutect_train_base_model.wdl new file mode 100644 index 00000000000..426c4ebf898 --- /dev/null +++ b/scripts/permutect/permutect_train_base_model.wdl @@ -0,0 +1,135 @@ +version 1.0 + + +workflow TrainPermutectBaseModel { + input { + File train_tar + File? pretrained_model + Int num_epochs + Int batch_size + Int inference_batch_size + Int? num_workers + Float dropout_p + Float reweighting_range + Array[Int] read_layers + Int self_attention_hidden_dimension + Int num_self_attention_layers + Array[Int] info_layers + Array[Int] aggregation_layers + Array[String] ref_seq_layer_strings + String? extra_args + Int? gpu_count + + String permutect_docker + Int? preemptible + Int? max_retries + } + + call TrainPermutectBase { + input: + train_tar = train_tar, + pretrained_model = pretrained_model, + permutect_docker = permutect_docker, + preemptible = preemptible, + max_retries = max_retries, + num_epochs = num_epochs, + batch_size = batch_size, + inference_batch_size = inference_batch_size, + num_workers = num_workers, + gpu_count = gpu_count, + dropout_p = dropout_p, + reweighting_range = reweighting_range, + read_layers = read_layers, + self_attention_hidden_dimension = self_attention_hidden_dimension, + num_self_attention_layers = num_self_attention_layers, + info_layers = info_layers, + aggregation_layers = aggregation_layers, + ref_seq_layer_strings = ref_seq_layer_strings, + extra_args = extra_args + } + + output { + File base_model = TrainPermutectBase.base_model + File training_tensorboard_tar = TrainPermutectBase.tensorboard_tar + } +} + + +task TrainPermutectBase { + input { + File train_tar + File? pretrained_model + + Int num_epochs + Int batch_size + Int inference_batch_size + Int? num_workers + Int? gpu_count + Float dropout_p + Float reweighting_range + Array[Int] read_layers + Int self_attention_hidden_dimension + Int num_self_attention_layers + Array[Int] info_layers + Array[Int] aggregation_layers + Array[String] ref_seq_layer_strings + + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + + command <<< + set -e + + gatk PermutectTrainBaseModel \ + --train-tar ~{train_tar} \ + ~{"--pretrained-model " + pretrained_model} \ + --read-layers ~{sep=' --read-layers ' read_layers} \ + --self-attention-hidden-dimension ~{self_attention_hidden_dimension} \ + --num-self-attention-layers ~{num_self_attention_layers} \ + --info-layers ~{sep=' --info-layers ' info_layers} \ + --aggregation-layers ~{sep=' --aggregation-layers ' aggregation_layers} \ + --ref-seq-layer-strings ~{sep=' --ref-seq-layer-strings ' ref_seq_layer_strings} \ + --dropout-p ~{dropout_p} \ + --reweighting-range ~{reweighting_range} \ + --batch-size ~{batch_size} \ + --inference-batch-size ~{inference_batch_size} \ + ~{"--num-workers " + num_workers} \ + --num-epochs ~{num_epochs} \ + --output base_model.pt \ + --tensorboard-dir tensorboard \ + ~{extra_args} + + tar cvf tensorboard.tar tensorboard/ + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 0]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + gpuType: "nvidia-tesla-t4" + gpuCount: select_first([gpu_count, 1]) + nvidiaDriverVersion: "535.183.01" + zones : ["us-central1-a", "us-central1-b", "us-central1-c"] + } + + output { + File base_model = "base_model.pt" + File tensorboard_tar = "tensorboard.tar" + } +} + diff --git a/scripts/permutect/permutect_training.wdl b/scripts/permutect/permutect_training.wdl new file mode 100644 index 00000000000..70b0136a367 --- /dev/null +++ b/scripts/permutect/permutect_training.wdl @@ -0,0 +1,130 @@ +version 1.0 + + +workflow TrainPermutect { + input { + File train_tar + File base_model + Int num_epochs + Int num_calibration_epochs + Int batch_size + Int inference_batch_size + Int? num_workers + Int? gpu_count + Float dropout_p + Array[Int] aggregation_layers + Array[Int] calibration_layers + String? extra_args + Boolean learn_artifact_spectra + Float? genomic_span + + String permutect_docker + Int? preemptible + Int? max_retries + Int? mem + } + + call TrainPermutect { + input: + train_tar = train_tar, + base_model = base_model, + permutect_docker = permutect_docker, + preemptible = preemptible, + max_retries = max_retries, + mem = mem, + num_epochs = num_epochs, + num_calibration_epochs = num_calibration_epochs, + batch_size = batch_size, + inference_batch_size = inference_batch_size, + num_workers = num_workers, + gpu_count = gpu_count, + dropout_p = dropout_p, + aggregation_layers = aggregation_layers, + calibration_layers = calibration_layers, + extra_args = extra_args, + learn_artifact_spectra = learn_artifact_spectra, + genomic_span = genomic_span + } + + + output { + File artifact_model = TrainPermutect.artifact_model + File training_tensorboard_tar = TrainPermutect.tensorboard_tar + } +} + + +task TrainPermutect { + input { + File train_tar + File base_model + + Int num_epochs + Int num_calibration_epochs + Int batch_size + Int inference_batch_size + Int? num_workers + Int? gpu_count + Float dropout_p + Array[Int] aggregation_layers + Array[Int] calibration_layers + Boolean learn_artifact_spectra + Float? genomic_span + + String? extra_args + + String permutect_docker + Int? preemptible + Int? max_retries + Int? disk_space + Int? cpu + Int? mem + } + + # Mem is in units of GB but our command and memory runtime values are in MB + Int machine_mem = if defined(mem) then mem * 1000 else 16000 + Int command_mem = machine_mem - 500 + String learn_artifact_cmd = if learn_artifact_spectra then "--learn_artifact_spectra" else "" + + command <<< + set -e + + gatk PermutectTrainArtifactModel \ + --train-tar ~{train_tar} \ + --base-model ~{base_model} \ + --aggregation-layers ~{sep=' --aggregation-layers ' aggregation_layers} \ + --calibration-layers ~{sep=' --calibration-layers ' calibration_layers} \ + --dropout-p ~{dropout_p} \ + --batch-size ~{batch_size} \ + --inference-batch-size ~{inference_batch_size} \ + ~{"--num-workers " + num_workers} \ + --num-epochs ~{num_epochs} \ + --num-calibration-epochs ~{num_calibration_epochs} \ + --output artifact.pt \ + --tensorboard-dir tensorboard \ + ~{"--genomic-span " + genomic_span} \ + ~{learn_artifact_cmd} \ + ~{extra_args} + + tar cvf tensorboard.tar tensorboard/ + >>> + + runtime { + docker: permutect_docker + bootDiskSizeGb: 12 + memory: machine_mem + " MB" + disks: "local-disk " + select_first([disk_space, 100]) + " SSD" + preemptible: select_first([preemptible, 0]) + maxRetries: select_first([max_retries, 0]) + cpu: select_first([cpu, 1]) + gpuType: "nvidia-tesla-t4" + gpuCount: select_first([gpu_count, 1]) + nvidiaDriverVersion: "535.183.01" + zones : ["us-central1-a", "us-central1-b", "us-central1-c"] + } + + output { + File artifact_model = "artifact.pt" + File tensorboard_tar = "tensorboard.tar" + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java new file mode 100644 index 00000000000..6422bfb6251 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstants.java @@ -0,0 +1,154 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import com.google.common.annotations.VisibleForTesting; +import org.broadinstitute.barclay.argparser.CommandLineArgumentParser; +import org.broadinstitute.barclay.argparser.CommandLineParser; +import org.broadinstitute.barclay.argparser.NamedArgumentDefinition; + +import java.lang.reflect.Field; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Map.entry; + +public class PermutectArgumentConstants { + + // Java-style (kebab case) without _K suffix + public static final String STATE_DICT_NAME = "model-state-dict"; + public static final String ARTIFACT_LOG_PRIORS_NAME = "artifact-log-priors"; + public static final String ARTIFACT_SPECTRA_STATE_DICT_NAME = "artifact-spectra-state-dict"; + public static final String HYPERPARAMS_NAME = "hyperparams"; + public static final String NUM_READ_FEATURES_NAME = "num-read-features"; + public static final String NUM_INFO_FEATURES_NAME = "num-info-features"; + public static final String REF_SEQUENCE_LENGTH_NAME = "ref-sequence-length"; + public static final String HIDDEN_LAYERS_NAME = "hidden-layers"; + public static final String NUM_BASE_FEATURES_NAME = "num-base-features"; + public static final String NUM_REF_ALT_FEATURES_NAME = "num-ref-alt-features"; + + public static final String SOURCES_NAME = "sources"; + public static final String SOURCE_NAME = "source"; + + public static final String INPUT_NAME = "input"; + public static final String OUTPUT_NAME = "output"; + public static final String OUTPUT_DIR_NAME = "output-dir"; + + public static final String READ_LAYERS_NAME = "read-layers"; + public static final String SELF_ATTENTION_HIDDEN_DIMENSION_NAME = "self-attention-hidden-dimension"; + public static final String NUM_SELF_ATTENTION_LAYERS_NAME = "num-self-attention-layers"; + + public static final String LEARNING_METHOD_NAME = "learning-method"; + + public static final String INFO_LAYERS_NAME = "info-layers"; + public static final String AGGREGATION_LAYERS_NAME = "aggregation-layers"; + public static final String CALIBRATION_LAYERS_NAME = "calibration-layers"; + public static final String REF_SEQ_LAYER_STRINGS_NAME = "ref-seq-layer-strings"; + public static final String DROPOUT_P_NAME = "dropout-p"; + public static final String LEARNING_RATE_NAME = "learning-rate"; + public static final String WEIGHT_DECAY_NAME = "weight-decay"; + public static final String BATCH_NORMALIZE_NAME = "batch-normalize"; + public static final String LEARN_ARTIFACT_SPECTRA_NAME = "learn-artifact-spectra"; + + public static final String TRAINING_DATASETS_NAME = "training-datasets"; + public static final String TRAIN_TAR_NAME = "train-tar"; + public static final String EVALUATION_TAR_NAME = "evaluation-tar"; + public static final String TEST_DATASET_NAME = "test-dataset"; + public static final String NORMAL_ARTIFACT_DATASETS_NAME = "normal-artifact-datasets"; + public static final String REWEIGHTING_RANGE_NAME = "reweighting-range"; + public static final String BATCH_SIZE_NAME = "batch-size"; + public static final String CHUNK_SIZE_NAME = "chunk-size"; + public static final String NUM_EPOCHS_NAME = "num-epochs"; + public static final String NUM_CALIBRATION_EPOCHS_NAME = "num-calibration-epochs"; + public static final String INFERENCE_BATCH_SIZE_NAME = "inference-batch-size"; + public static final String NUM_WORKERS_NAME = "num-workers"; + public static final String NUM_SPECTRUM_ITERATIONS_NAME = "num-spectrum-iterations"; + public static final String SPECTRUM_LEARNING_RATE_NAME = "spectrum-learning-rate"; + + public static final String DATASET_EDIT_TYPE_NAME = "dataset-edit"; + + public static final String TENSORBOARD_DIR_NAME = "tensorboard-dir"; + + public static final String INITIAL_LOG_VARIANT_PRIOR_NAME = "initial-log-variant-prior"; + public static final String INITIAL_LOG_ARTIFACT_PRIOR_NAME = "initial-log-artifact-prior"; + public static final String CONTIGS_TABLE_NAME = "contigs-table"; + public static final String GENOMIC_SPAN_NAME = "genomic-span"; + public static final String MAF_SEGMENTS_NAME = "maf-segments"; + public static final String NORMAL_MAF_SEGMENTS_NAME = "normal-maf-segments"; + public static final String GERMLINE_MODE_NAME = "germline-mode"; + public static final String NO_GERMLINE_MODE_NAME = "no-germline-mode"; + public static final String HET_BETA_NAME = "het-beta"; + + public static final String BASE_MODEL_NAME = "base-model"; + public static final String M3_MODEL_NAME = "permutect-model"; + public static final String PRETRAINED_MODEL_NAME = "pretrained-model"; + + @VisibleForTesting + static final Map PERMUTECT_PYTHON_ARGUMENT_MAP = Collections.unmodifiableMap(generateArgumentMap()); + + + /** + * Takes in the command line parser for a permutect tool and converts and returns a string list of all of the appropriate arguments + * for the wrapped python script that are A) actually present for the tool and B) have been set by the user. + * + * @param parser the command line parser for the tool in question from which to generate python arguments + */ + //TODO this might be easier done by directly taking the input arguments directly + public static List getPtyhonClassArgumentsFromToolParser(CommandLineParser parser) { + if (parser instanceof CommandLineArgumentParser argParser) { + List pythonArgs = new ArrayList<>(); + for (Map.Entry entry : PERMUTECT_PYTHON_ARGUMENT_MAP.entrySet()) { + NamedArgumentDefinition arg = argParser.getNamedArgumentDefinitionByAlias(entry.getKey()); + if (arg != null && arg.getHasBeenSet()) { // arg can be null if it is not actually a valid argument for the tool in question + pythonArgs.add("--" + entry.getValue()); + + //TODO double check the toString() method for the argument value + if (arg.isFlag()) { + continue; // flags don't have values + } else if (arg.isCollection()) { + // The python argument code for permutect expects a sequenctial list of strings following the list argument + ((Collection) arg.getArgumentValue()).forEach(value -> pythonArgs.add(value.toString())); + } else { + pythonArgs.add(arg.getArgumentValue().toString()); + } + } + } + return pythonArgs; + + } else { + throw new IllegalArgumentException("command line parser is not CommandLineArgumentParser"); + } + } + + /** + * A number of utilities to make converting from the java wrappers to the python methods as easy as possible. + */ + private static String convertToPythonStyle(String javaStyle) { + return javaStyle.replace('-', '_'); + } + + /** + * Generate the static map using reflection. + */ + public static Map generateArgumentMap() { + return Stream.of(PermutectArgumentConstants.class.getDeclaredFields()) + .filter(field -> java.lang.reflect.Modifier.isStatic(field.getModifiers()) + && java.lang.reflect.Modifier.isFinal(field.getModifiers()) + && field.getType().equals(String.class)) + .collect(Collectors.toMap( + PermutectArgumentConstants::getFieldValue, // Java-style name + field -> convertToPythonStyle(getFieldValue(field)) // Python-style name + )); + } + + /** + * Safely get the value of a static final field. + */ + private static String getFieldValue(Field field) { + try { + return (String) field.get(null); + } catch (IllegalAccessException e) { + throw new RuntimeException("Unable to access field: " + field.getName(), e); + } + } + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArtifactModelArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArtifactModelArgumentCollection.java new file mode 100644 index 00000000000..005c0829d9c --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectArtifactModelArgumentCollection.java @@ -0,0 +1,38 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; + +import java.io.Serializable; +import java.util.List; + +public class PermutectArtifactModelArgumentCollection implements Serializable { + private static final long serialVersionUID = 1L; + @Argument( + doc = "Dimensions of hidden layers in the aggregation subnetwork, excluding the dimension of input from lower subnetworks and the dimension (1) of the output logit. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, + optional = false + ) + public List aggregationLayers; + + @Argument( + doc = "Dimensions of hidden layers in the calibration subnetwork, excluding the dimension (1) of input logit and the dimension (also 1) of the output logit.", + fullName = PermutectArgumentConstants.CALIBRATION_LAYERS_NAME, + optional = false + ) + public List calibrationLayers; + + @Argument( + doc = "Dropout probability.", + fullName = PermutectArgumentConstants.DROPOUT_P_NAME, + optional = true + ) + public String dropoutP = "0.0"; + + @Argument( + doc = "Flag to turn on batch normalization.", + fullName = PermutectArgumentConstants.BATCH_NORMALIZE_NAME, + optional = true + ) + public boolean batchNormalize = false; + +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java new file mode 100644 index 00000000000..b2c5fa99292 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectBaseModelArgumentCollection.java @@ -0,0 +1,79 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; + +import java.io.Serializable; +import java.util.List; + +public class PermutectBaseModelArgumentCollection implements Serializable { + private static final long serialVersionUID = 1L; + @Argument( + doc = "Optional pretrained base model to initialize training.", + fullName = PermutectArgumentConstants.PRETRAINED_MODEL_NAME, + optional = true + ) + public String pretrainedModelName = null; + + @Argument( + doc = "Dimensions of hidden layers in the read embedding subnetwork, including the dimension of the embedding itself. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.READ_LAYERS_NAME, + optional = false + ) + public List readLayers = null; + + @Argument( + doc = "Hidden dimension of transformer keys and values in the self-attention layers.", + fullName = PermutectArgumentConstants.SELF_ATTENTION_HIDDEN_DIMENSION_NAME, + optional = false + ) + public String selfAttentionHiddenDimension = null; + + @Argument( + doc = "Number of symmetric gated MLP self-attention layers.", + fullName = PermutectArgumentConstants.NUM_SELF_ATTENTION_LAYERS_NAME, + optional = false + ) + public String numSelfAttentionLayers = null; + + @Argument( + doc = "Dimensions of hidden layers in the info embedding subnetwork, including the dimension of the embedding itself. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.INFO_LAYERS_NAME, + optional = false + ) + public List infoLayers = null; + + @Argument( + doc = "Dimensions of hidden layers in the aggregation subnetwork, excluding the dimension of input from lower subnetworks and the dimension (1) of the output logit. Negative values indicate residual skip connections.", + fullName = PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, + optional = false + ) + public List aggregationLayers = null; + + @Argument( + doc = "List of strings specifying convolution layers of the reference sequence embedding. For example: convolution/kernel_size=3/out_channels=64 pool/kernel_size=2 leaky_relu convolution/kernel_size=3/dilation=2/out_channels=5 leaky_relu flatten linear/out_features=10.", + fullName = PermutectArgumentConstants.REF_SEQ_LAYER_STRINGS_NAME, + optional = false + ) + public List refSeqLayerStrings = null; + + @Argument( + doc = "Dropout probability (default: 0.0).", + fullName = PermutectArgumentConstants.DROPOUT_P_NAME, + optional = true + ) + public String dropoutP = "0.0"; + + @Argument( + doc = "Magnitude of data augmentation by randomly weighted average of read embeddings. A value of x yields random weights between 1 - x and 1 + x (default: 0.3).", + fullName = PermutectArgumentConstants.REWEIGHTING_RANGE_NAME, + optional = true + ) + public String reweightingRange = "0.3"; + + @Argument( + doc = "Flag to turn on batch normalization.", + fullName = PermutectArgumentConstants.BATCH_NORMALIZE_NAME, + optional = true + ) + public Boolean batchNormalize = false; +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectEditDataset.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectEditDataset.java new file mode 100644 index 00000000000..788dd834fb7 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectEditDataset.java @@ -0,0 +1,72 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.BetaFeature; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.engine.GATKTool; +import org.broadinstitute.hellbender.utils.io.Resource; +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; +import picard.cmdline.programgroups.ReadDataManipulationProgramGroup; +import picard.cmdline.programgroups.VariantFilteringProgramGroup; + +import java.util.List; + +@CommandLineProgramProperties( + summary = "train the Mutect3 artifact model.", //TODO this needs to be properly labeled + oneLineSummary = "train the Mutect3 artifact modela", + programGroup = VariantFilteringProgramGroup.class +) +@DocumentedFeature +@BetaFeature +public class PermutectEditDataset extends CommandLineProgram { + public static final String PERMUTECT_EDIT_DATASET = "edit_dataset.py"; + + @Argument( + doc = "Size in bytes of output binary data files.", + fullName = PermutectArgumentConstants.CHUNK_SIZE_NAME, + optional = true + ) + public String chunkSize = String.valueOf((int) 2e9); + + @Argument( + doc = "How to modify the dataset.", + fullName = PermutectArgumentConstants.DATASET_EDIT_TYPE_NAME, + optional = false + ) + public String datasetEditType; + + @Argument( + doc = "New source integer to apply.", + fullName = PermutectArgumentConstants.SOURCE_NAME, + optional = true + ) + public String source; + + @Argument( + doc = "Tarfile(s) of training/validation datasets produced by preprocess_dataset.py.", + fullName = PermutectArgumentConstants.TRAIN_TAR_NAME, + optional = false + ) + public List trainTar; + + @Argument( + doc = "Path to pruned dataset file.", + fullName = PermutectArgumentConstants.OUTPUT_NAME, + optional = false + ) + public String output; + + @Override + protected Object doWork() { + PythonScriptExecutor executor = new PythonScriptExecutor(true); + List pythonifiedArguments = PermutectArgumentConstants.getPtyhonClassArgumentsFromToolParser(getCommandLineParser()); + + return executor.executeScript( + new Resource(PERMUTECT_EDIT_DATASET, PermutectTrainBaseModel.class), + null, + pythonifiedArguments); + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectFilterVariants.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectFilterVariants.java new file mode 100644 index 00000000000..2a85098f708 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectFilterVariants.java @@ -0,0 +1,169 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.BetaFeature; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.utils.io.Resource; +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; +import picard.cmdline.programgroups.VariantFilteringProgramGroup; + +import java.util.List; + +@CommandLineProgramProperties( + summary = "train the Permutect read set representation model.", + oneLineSummary = "train the Permutect read set representation model", + programGroup = VariantFilteringProgramGroup.class +) +@DocumentedFeature +@BetaFeature +public class PermutectFilterVariants extends CommandLineProgram { + + public static final String FILTER_VARIANTS = "permutect_filter_variants.py"; + + @Argument( + doc = "Unfiltered input Mutect2 VCF.", + fullName = PermutectArgumentConstants.INPUT_NAME, + optional = false + ) + public String inputName; + + @Argument( + doc = "Plain text dataset file corresponding to variants in input VCF.", + fullName = PermutectArgumentConstants.TEST_DATASET_NAME, + optional = false + ) + public String testDatasetName; + + @Argument( + doc = "Trained Permutect model from train_model.py.", + fullName = PermutectArgumentConstants.M3_MODEL_NAME, + optional = false + ) + public String m3ModelName; + + @Argument( + doc = "Table of contig names vs integer indices.", + fullName = PermutectArgumentConstants.CONTIGS_TABLE_NAME, + optional = false + ) + public String contigsTableName; + + @Argument( + doc = "Path to output filtered VCF.", + fullName = PermutectArgumentConstants.OUTPUT_NAME, + optional = false + ) + public String outputName; + + @Argument( + doc = "Path to output tensorboard directory.", + fullName = PermutectArgumentConstants.TENSORBOARD_DIR_NAME, + optional = true + ) + public String tensorboardDirName = "tensorboard"; + + @Argument( + doc = "Batch size.", + fullName = PermutectArgumentConstants.BATCH_SIZE_NAME, + optional = true + ) + public String batchSize = "64"; + + @Argument( + doc = "Number of subprocesses devoted to data loading, including reading from memory map, collating batches, and transferring to GPU.", + fullName = PermutectArgumentConstants.NUM_WORKERS_NAME, + optional = true + ) + public String numWorkers = "0"; + + @Argument( + doc = "Size in bytes of intermediate binary datasets.", + fullName = PermutectArgumentConstants.CHUNK_SIZE_NAME, + optional = true + ) + public String chunkSize = "100000"; + + @Argument( + doc = "Number of epochs for fitting allele fraction spectra.", + fullName = PermutectArgumentConstants.NUM_SPECTRUM_ITERATIONS_NAME, + optional = true + ) + public String numSpectrumIterations = "10"; + + @Argument( + doc = "Learning rate for fitting allele fraction spectra.", + fullName = PermutectArgumentConstants.SPECTRUM_LEARNING_RATE_NAME, + optional = true + ) + public String spectrumLearningRate = "0.001"; + + @Argument( + doc = "Initial value for natural log prior of somatic variants.", + fullName = PermutectArgumentConstants.INITIAL_LOG_VARIANT_PRIOR_NAME, + optional = true + ) + public String initialLogVariantPrior = "-10.0"; + + @Argument( + doc = "Initial value for natural log prior of artifacts.", + fullName = PermutectArgumentConstants.INITIAL_LOG_ARTIFACT_PRIOR_NAME, + optional = true + ) + public String initialLogArtifactPrior = "-10.0"; + + @Argument( + doc = "Number of sites considered by Mutect2, including those lacking variation or artifacts, hence absent from input dataset. Necessary for learning priors since otherwise rates of artifacts and variants would be overinflated.", + fullName = PermutectArgumentConstants.GENOMIC_SPAN_NAME, + optional = false + ) + public String genomicSpan; + + @Argument( + doc = "Copy-number segmentation file from GATK containing minor allele fractions. Useful for modeling germline variation as the minor allele fraction determines the distribution of germline allele counts.", + fullName = PermutectArgumentConstants.MAF_SEGMENTS_NAME, + optional = true + ) + public String mafSegmentsName; + + @Argument( + doc = "Copy-number segmentation file from GATK containing minor allele fractions in the normal/control sample.", + fullName = PermutectArgumentConstants.NORMAL_MAF_SEGMENTS_NAME, + optional = true + ) + public String normalMafSegmentsName; + + @Argument( + doc = "Flag for genotyping both somatic and somatic variants distinctly but considering both as non-errors (true positives), which affects the posterior threshold set by optimal F1 score.", + fullName = PermutectArgumentConstants.GERMLINE_MODE_NAME, + optional = true + ) + public boolean germlineMode = false; + + @Argument( + doc = "Beta shape parameter for germline spectrum beta binomial if we want to override binomial.", + fullName = PermutectArgumentConstants.HET_BETA_NAME, + optional = true + ) + public String hetBeta; + + @Argument( + doc = "Flag for not genotyping germline events so that the only possibilities considered are somatic, artifact, and sequencing error. This is useful for certain validation where pseudo-somatic events are created by mixing germline events at varying fractions.", + fullName = PermutectArgumentConstants.NO_GERMLINE_MODE_NAME, + optional = true + ) + public boolean noGermlineMode = false; + + @Override + protected Object doWork() { + PythonScriptExecutor executor = new PythonScriptExecutor(true); + List pythonifiedArguments = PermutectArgumentConstants.getPtyhonClassArgumentsFromToolParser(getCommandLineParser()); + + return executor.executeScript( + new Resource(FILTER_VARIANTS, PermutectTrainBaseModel.class), + null, + pythonifiedArguments); + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java new file mode 100644 index 00000000000..510e59e3ba9 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectPreprocessDataset.java @@ -0,0 +1,68 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.BetaFeature; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.tools.copynumber.GermlineCNVCaller; +import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils; +import org.broadinstitute.hellbender.utils.io.Resource; +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; +import picard.cmdline.programgroups.VariantFilteringProgramGroup; + +import java.util.ArrayList; +import java.util.List; + +@CommandLineProgramProperties( + summary = "Preprocess plain text training dataset into tarfile of normalized binary data for permutect.", + oneLineSummary = "Preprocess plain text training dataset into tarfile of normalized binary data for permutect", + programGroup = VariantFilteringProgramGroup.class +) +@DocumentedFeature +@BetaFeature +public class PermutectPreprocessDataset extends CommandLineProgram { + + public static final String PERMUTECT_PREPREOCESS_DATASET_SCRIPT = "preprocess_dataset.py"; + + //TODO handle lists for this? Make it a gatk list? + @Argument( + doc = "List of plain text data files.", + fullName = PermutectArgumentConstants.TRAINING_DATASETS_NAME + ) + public String trainingDatasetName = null; + + @Argument( + doc = "Size in bytes of output binary data files. Default is 2e9.", + fullName = PermutectArgumentConstants.CHUNK_SIZE_NAME, + optional = true + ) + public String chunkSizeName = null; + + @Argument( + doc = "Integer sources corresponding to plain text data files for distinguishing different sequencing conditions.", + fullName = PermutectArgumentConstants.SOURCES_NAME, + optional = true + ) + public String sources = null; + + @Argument( + doc = "Path to output tarfile of training data.", + fullName = PermutectArgumentConstants.OUTPUT_NAME + ) + public String outputTarGz = null; + + + @Override + protected Object doWork() { + PythonScriptExecutor executor = new PythonScriptExecutor(true); + List pythonifiedArguments = PermutectArgumentConstants.getPtyhonClassArgumentsFromToolParser(getCommandLineParser()); + + return executor.executeScript( + new Resource(PERMUTECT_PREPREOCESS_DATASET_SCRIPT, PermutectTrainBaseModel.class), + null, + pythonifiedArguments); + } + +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainArtifactModel.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainArtifactModel.java new file mode 100644 index 00000000000..2f03272069e --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainArtifactModel.java @@ -0,0 +1,84 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.BetaFeature; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.utils.io.Resource; +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; +import picard.cmdline.programgroups.VariantFilteringProgramGroup; + +import java.util.List; + +@CommandLineProgramProperties( + summary = "train the Permutect read set representation model.", + oneLineSummary = "train the Permutect read set representation model", + programGroup = VariantFilteringProgramGroup.class +) +@DocumentedFeature +@BetaFeature +public class PermutectTrainArtifactModel extends CommandLineProgram { + + public static final String TRAIN_BASE_MODEL_PY = "train_model.py"; + + @Argument( + doc = "Flag to include artifact priors and allele fraction spectra in saved output. This is worth doing if labeled training data is available but might work poorly when Mutect3 generates weak labels based on allele fractions.", + fullName = PermutectArgumentConstants.LEARN_ARTIFACT_SPECTRA_NAME, + optional = true + ) + public boolean learnArtifactSpectra = false; + + @Argument( + doc = "Total number of sites considered by Mutect2 in all training data, including those lacking variation or artifacts, hence absent from input datasets. Necessary for learning priors since otherwise rates of artifacts and variants would be overinflated. Only required if learning artifact log priors.", + fullName = PermutectArgumentConstants.GENOMIC_SPAN_NAME, + optional = true + ) + public String genomicSpan; + + @Argument( + doc = "Tarfile of training/validation datasets produced by preprocess_dataset.py.", + fullName = PermutectArgumentConstants.TRAIN_TAR_NAME, + optional = false + ) + public String trainTarName; + + @Argument( + doc = "Base model from train_base_model.py.", + fullName = PermutectArgumentConstants.BASE_MODEL_NAME, + optional = true + ) + public String baseModelName; + + @Argument( + doc = "Path to output saved model file.", + fullName = PermutectArgumentConstants.OUTPUT_NAME, + optional = false + ) + public String outputName; + + @Argument( + doc = "Path to output tensorboard directory.", + fullName = PermutectArgumentConstants.TENSORBOARD_DIR_NAME, + optional = true + ) + public String tensorboardDirName = "tensorboard"; + + // Shared argument collections to include in arguments + @ArgumentCollection + PermutectArtifactModelArgumentCollection artifactModelArgs = new PermutectArtifactModelArgumentCollection(); + @ArgumentCollection + PermutectTrainingParamsArgumentCollection trainingParamsArgumentCollection = new PermutectTrainingParamsArgumentCollection(); + + @Override + protected Object doWork() { + PythonScriptExecutor executor = new PythonScriptExecutor(true); + List pythonifiedArguments = PermutectArgumentConstants.getPtyhonClassArgumentsFromToolParser(getCommandLineParser()); + + return executor.executeScript( + new Resource(TRAIN_BASE_MODEL_PY, PermutectTrainBaseModel.class), + null, + pythonifiedArguments); + } +} diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java new file mode 100644 index 00000000000..4894112bbf3 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainBaseModel.java @@ -0,0 +1,75 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.BetaFeature; +import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; +import org.broadinstitute.barclay.help.DocumentedFeature; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; +import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils; +import org.broadinstitute.hellbender.utils.io.Resource; +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; +import picard.cmdline.programgroups.VariantFilteringProgramGroup; + +import java.util.ArrayList; +import java.util.List; + +@CommandLineProgramProperties( + summary = "train the Permutect read set representation model.", + oneLineSummary = "train the Permutect read set representation model", + programGroup = VariantFilteringProgramGroup.class +) +@DocumentedFeature +@BetaFeature +public class PermutectTrainBaseModel extends CommandLineProgram { + + public static final String TRAIN_BASE_MODEL_PY = "train_base_model.py"; + + @Argument( + doc = "Options [SUPERVISED, SEMISUPERVISED, SUPERVISED_CLUSTERING, AFFINE, MASK_PREDICTION, AUTOENCODER, DEEPSAD, MARS].", + fullName = PermutectArgumentConstants.LEARNING_METHOD_NAME, + optional = true + ) + public String trainingDatasetName = null; + + @Argument( + doc = "Tarfile of training/validation datasets produced by preprocess_dataset.", + fullName = PermutectArgumentConstants.TRAIN_TAR_NAME, + optional = false + ) + public String chunkSizeName = null; + + @Argument( + doc = "Output location for the saved model file.", + fullName = PermutectArgumentConstants.OUTPUT_NAME, + optional = false + ) + public String sources = null; + + @Argument( + doc = "output tensorboard directory.", + fullName = PermutectArgumentConstants.TENSORBOARD_DIR_NAME, + optional = true + ) + public String outputTarGz = null; + + // Shared argument collections to include in arguments + @ArgumentCollection + PermutectBaseModelArgumentCollection baseArgumentCollection = new PermutectBaseModelArgumentCollection(); + @ArgumentCollection + PermutectTrainingParamsArgumentCollection trainingParamsArgumentCollection = new PermutectTrainingParamsArgumentCollection(); + + @Override + protected Object doWork() { + PythonScriptExecutor executor = new PythonScriptExecutor(true); + List pythonifiedArguments = PermutectArgumentConstants.getPtyhonClassArgumentsFromToolParser(getCommandLineParser()); + + return executor.executeScript( + new Resource(TRAIN_BASE_MODEL_PY, PermutectTrainBaseModel.class), + null, + pythonifiedArguments); + } + + +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java new file mode 100644 index 00000000000..15e51de7430 --- /dev/null +++ b/src/main/java/org/broadinstitute/hellbender/tools/permutect/PermutectTrainingParamsArgumentCollection.java @@ -0,0 +1,55 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; + +public class PermutectTrainingParamsArgumentCollection { + @Argument( + doc = "Learning rate for the model.", + fullName = PermutectArgumentConstants.LEARNING_RATE_NAME, + optional = true + ) + public String learningRate = "0.001"; + + @Argument( + doc = "Weight decay for the optimizer.", + fullName = PermutectArgumentConstants.WEIGHT_DECAY_NAME, + optional = true + ) + public String weightDecay = "0.0"; + + @Argument( + doc = "Batch size for training.", + fullName = PermutectArgumentConstants.BATCH_SIZE_NAME, + optional = true + ) + public String batchSize = "64"; + + @Argument( + doc = "Number of subprocesses devoted to data loading, including reading from memory map, collating batches, and transferring to GPU.", + fullName = PermutectArgumentConstants.NUM_WORKERS_NAME, + optional = true + ) + public String numWorkers = "0"; + + @Argument( + doc = "Number of epochs for primary training loop.", + fullName = PermutectArgumentConstants.NUM_EPOCHS_NAME, + optional = false + ) + public String numEpochs; + + @Argument( + doc = "Number of calibration-only epochs.", + fullName = PermutectArgumentConstants.NUM_CALIBRATION_EPOCHS_NAME, + optional = true + ) + public String numCalibrationEpochs = "0"; + + @Argument( + doc = "Batch size when performing model inference (not training).", + fullName = PermutectArgumentConstants.INFERENCE_BATCH_SIZE_NAME, + optional = true + ) + public String inferenceBatchSize = "8192"; + +} diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/artifact_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/artifact_model.py new file mode 100644 index 00000000000..7adf7cbe57b --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/artifact_model.py @@ -0,0 +1,508 @@ +# bug before PyTorch 1.7.1 that warns when constructing ParameterList +import math +import time +import warnings +from collections import defaultdict +from typing import List + +import psutil +import torch +from torch import nn, Tensor +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from queue import PriorityQueue + + +from tqdm.autonotebook import trange, tqdm +from itertools import chain +from matplotlib import pyplot as plt + +from permutect.architecture.base_model import calculate_batch_weights, BaseModel, base_model_from_saved_dict, calculate_batch_source_weights +from permutect.architecture.gradient_reversal.module import GradientReversal +from permutect.architecture.mlp import MLP +from permutect.architecture.monotonic import MonoDense +from permutect.data.base_datum import ArtifactBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT +from permutect.data.artifact_dataset import ArtifactDataset +from permutect import utils, constants +from permutect.metrics.evaluation_metrics import LossMetrics, EvaluationMetrics, MAX_COUNT, round_up_to_nearest_three, \ + EmbeddingMetrics, multiple_of_three_bin_index_to_count, multiple_of_three_bin_index +from permutect.parameters import TrainingParameters, ArtifactModelParameters +from permutect.utils import Variation, Epoch, Label +from permutect.metrics import plotting + +warnings.filterwarnings("ignore", message="Setting attributes on ParameterList is not supported.") + + +WORST_OFFENDERS_QUEUE_SIZE = 100 + + +def effective_count(weights: Tensor): + return (torch.square(torch.sum(weights)) / torch.sum(torch.square(weights))).item() + + +# group rows into consecutive chunks to yield a 3D tensor, average over dim=1 to get +# 2D tensor of sums within each chunk +def sums_over_chunks(tensor2d: Tensor, chunk_size: int): + assert len(tensor2d) % chunk_size == 0 + return torch.sum(tensor2d.reshape([len(tensor2d) // chunk_size, chunk_size, -1]), dim=1) + + +class Calibration(nn.Module): + + def __init__(self, hidden_layer_sizes: List[int]): + super(Calibration, self).__init__() + + # calibration takes [logit, ref count, alt count] as input and maps it to [calibrated logit] + # it is monotonically increasing in logit, unconstrained in ref and alt count + # we initialize it to calibrated logit = input logit + + # likewise, we cap the effective alt and ref counts and input logits to avoid arbitrarily large confidence + self.max_alt = nn.Parameter(torch.tensor(20.0)) + self.max_ref = nn.Parameter(torch.tensor(20.0)) + self.max_input_logit = nn.Parameter(torch.tensor(20.0)) + + center_spacing = 1 + ref_center_spacing = 5 + + # centers of Gaussian comb featurizations + # note: even though they aren't learned and requires_grad is False, we still wrap them in nn.Parameter + # so that they can be sent to GPU recursively when the grandparent ArtifactModel is + self.alt_centers = nn.Parameter(torch.arange(start=1, end=20, step=center_spacing), requires_grad=False) + self.ref_centers = nn.Parameter(torch.arange(start=1, end=20, step=ref_center_spacing), requires_grad=False) + + # increasing in the 1st feature, logits + # logit is one feature, then the Gaussian comb for alt and ref counts is the other + self.monotonic = MonoDense(1 + len(self.ref_centers) + len(self.alt_centers), hidden_layer_sizes + [1], 1, 0) + + self.is_turned_on = True + + self.max_alt_count_for_adjustment = 20 + # after training we compute one final calibration adjustment, which depends on alt count + # the nth element is the adjustment for alt count n + # note that this is NOT a learnable parameter!!!! It is *set* but not learned!! + self.final_adjustments = nn.Parameter(torch.zeros(self.max_alt_count_for_adjustment + 1), requires_grad=False) + + def set_adjustments(self, adjustments: torch.Tensor): + current_device, current_dtype = self.final_adjustments.device, self.final_adjustments.dtype + clipped_adjustments = adjustments[:len(self.final_adjustments)] + padding_needed = len(self.final_adjustments) - len(clipped_adjustments) + padded_adjustments = torch.hstack((clipped_adjustments, clipped_adjustments[-1] * torch.ones(padding_needed))) if padding_needed else clipped_adjustments + self.final_adjustments = nn.Parameter(padded_adjustments.to(device=current_device, dtype=current_dtype), requires_grad=False) + self.max_alt_count_for_adjustment = len(adjustments) - 1 + + def calibrated_logits(self, logits_b: Tensor, ref_counts_b: Tensor, alt_counts_b: Tensor): + if self.is_turned_on: + logits_bc = torch.tanh(logits_b / self.max_input_logit)[:, None] + + ref_comb_bc = torch.softmax(-torch.square(ref_counts_b[:, None] - self.ref_centers[None, :]).float(), dim=1) + alt_comb_bc = torch.softmax(-torch.square(alt_counts_b[:, None] - self.alt_centers[None, :]).float(), dim=1) + input_2d = torch.hstack([logits_bc, ref_comb_bc, alt_comb_bc]) + calibrated_b = self.monotonic.forward(input_2d).squeeze() + + counts_for_adjustment = torch.clamp(alt_counts_b, max=self.max_alt_count_for_adjustment).long() + adjustments = self.final_adjustments[counts_for_adjustment] + + return calibrated_b + adjustments + else: # should never happen + return logits_b + + def forward(self, logits, ref_counts: Tensor, alt_counts: Tensor): + return self.calibrated_logits(logits, ref_counts, alt_counts) + + def plot_calibration(self): + device, dtype = self.final_adjustments.device, self.final_adjustments.dtype + alt_counts = [1, 3, 5, 10, 15, 20] + ref_counts = [1, 3, 5, 10, 15, 20] + logits = torch.arange(-10, 10, 0.1, device=device, dtype=dtype) + cal_fig,cal_axes = plt.subplots(len(alt_counts), len(ref_counts), sharex='all', sharey='all', + squeeze=False, figsize=(10, 6), dpi=100) + + for row_idx, alt_count in enumerate(alt_counts): + for col_idx, ref_count in enumerate(ref_counts): + calibrated = self.forward(logits, ref_count * torch.ones_like(logits, device=device, dtype=dtype), alt_count * torch.ones_like(logits, device=device, dtype=dtype)) + plotting.simple_plot_on_axis(cal_axes[row_idx, col_idx], [(logits.detach().cpu(), calibrated.detach().cpu(), "")], None, None) + + plotting.tidy_subplots(cal_fig, cal_axes, x_label="alt count", y_label="ref count", + row_labels=[str(n) for n in ref_counts], column_labels=[str(n) for n in alt_counts]) + + return cal_fig, cal_axes + + +class ArtifactModel(nn.Module): + """ + aggregation_layers: dimensions of layers for aggregation, excluding its input which is determined by the + representation model. + + output_layers: dimensions of layers after aggregation, excluding the output dimension, + which is 1 for a single logit representing artifact/non-artifact. This is not part of the aggregation layers + because we have different output layers for each variant type. + """ + + def __init__(self, params: ArtifactModelParameters, num_base_features: int, num_ref_alt_features: int, device=utils.gpu_if_available()): + super(ArtifactModel, self).__init__() + + self._device = device + self._dtype = DEFAULT_GPU_FLOAT if device != torch.device("cpu") else DEFAULT_CPU_FLOAT + self.num_base_features = num_base_features + self.num_ref_alt_features = num_ref_alt_features + self.params = params + + # feature layers before the domain adaptation source classifier splits from the artifact classifier + self.feature_layers = MLP([num_base_features] + params.aggregation_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + + # TODO: artifact classifier hidden layers are hard-coded!!! + # The [1] is for the output logit + self.artifact_classifier = MLP([self.feature_layers.output_dimension()] + [-1, -1, 1], batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + + # one Calibration module for each variant type; that is, calibration depends on both count and type + self.calibration = nn.ModuleList([Calibration(params.calibration_layers) for variant_type in Variation]) + + self.to(device=self._device, dtype=self._dtype) + + def training_parameters(self): + return chain(self.feature_layers.parameters(), self.artifact_classifier.parameters(), self.calibration.parameters()) + + def calibration_parameters(self): + return self.calibration.parameters() + + def freeze_all(self): + utils.freeze(self.parameters()) + + def set_epoch_type(self, epoch_type: utils.Epoch): + if epoch_type == utils.Epoch.TRAIN: + self.train(True) + utils.freeze(self.parameters()) + utils.unfreeze(self.training_parameters()) + else: + self.freeze_all() + + # returns 1D tensor of length batch_size of log odds ratio (logits) between artifact and non-artifact + def forward(self, batch: ArtifactBatch): + # batch has already gotten copy_to(self._device, self._dtype) + features = self.feature_layers.forward(batch.get_representations_2d()) + uncalibrated_logits = self.artifact_classifier.forward(features).reshape(batch.size()) + calibrated_logits = torch.zeros_like(uncalibrated_logits, device=self._device) + variant_types = batch.get_variant_types() + for n, _ in enumerate(Variation): + mask = (variant_types == n) + calibrated_logits += mask * self.calibration[n].forward(uncalibrated_logits, batch.get_ref_counts(), batch.get_alt_counts()) + return calibrated_logits, uncalibrated_logits, features + + def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, summary_writer: SummaryWriter, validation_fold: int = None, epochs_per_evaluation: int = None): + bce = nn.BCEWithLogitsLoss(reduction='none') # no reduction because we may want to first multiply by weights for unbalanced data + # cross entropy (with logit inputs) loss for adversarial source classification task + ce = nn.CrossEntropyLoss(reduction='none') + + num_sources = len(dataset.counts_by_source.keys()) + if num_sources == 1: + print("Training data come from a single source (this could be multiple files with the same source annotation applied in preprocessing)") + else: + sources_list = list(dataset.counts_by_source.keys()) + sources_list.sort() + assert sources_list[0] == 0, "There is no source 0" + assert sources_list[-1] == num_sources - 1, f"sources should be 0, 1, 2. . . without gaps, but sources are {sources_list}." + + print(f"Training data come from multiple sources, with counts {dataset.counts_by_source}.") + source_classifier = MLP([self.feature_layers.output_dimension()] + [-1, -1, num_sources], + batch_normalize=self.params.batch_normalize, dropout_p=self.params.dropout_p) + source_classifier.to(device=self._device, dtype=self._dtype) + source_gradient_reversal = GradientReversal(alpha=0.01) # initialize as barely active + source_gradient_reversal.to(device=self._device, dtype=self._dtype) + + # TODO: fused = is_cuda? + train_optimizer = torch.optim.AdamW(chain(self.training_parameters(), source_classifier.parameters()), lr=training_params.learning_rate, + weight_decay=training_params.weight_decay) + train_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(train_optimizer, factor=0.2, patience=5, + threshold=0.001, min_lr=(training_params.learning_rate / 100), verbose=True) + + for idx, variation_type in enumerate(utils.Variation): + print(f"For variation type {variation_type.name}, there are {int(dataset.totals[-1][Label.ARTIFACT][idx].item())} \ + artifacts, {int(dataset.totals[-1][Label.VARIANT][idx].item())} \ + non-artifacts, and {int(dataset.totals[-1][Label.UNLABELED][idx].item())} unlabeled data.") + + is_cuda = self._device.type == 'cuda' + print(f"Is CUDA available? {is_cuda}") + + validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold + train_loader = dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), training_params.batch_size, is_cuda, training_params.num_workers) + print(f"Train loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}") + valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.inference_batch_size, is_cuda, training_params.num_workers) + print(f"Validation loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}") + + first_epoch, last_epoch = 1, training_params.num_epochs + training_params.num_calibration_epochs + for epoch in trange(1, last_epoch + 1, desc="Epoch"): + start_of_epoch = time.time() + print(f"Epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}") + is_calibration_epoch = epoch > training_params.num_epochs + + p = epoch - 1 + new_alpha = (2 / (1 + math.exp(-0.1 * p))) - 1 + source_gradient_reversal.set_alpha(new_alpha) + + for epoch_type in [utils.Epoch.TRAIN, utils.Epoch.VALID]: + self.set_epoch_type(epoch_type) + # in calibration epoch, freeze the model except for calibration + if is_calibration_epoch and epoch_type == utils.Epoch.TRAIN: + utils.freeze(self.parameters()) + utils.unfreeze(self.calibration_parameters()) # unfreeze calibration but everything else stays frozen + + loss_metrics = LossMetrics() # based on calibrated logits + source_prediction_loss_metrics = LossMetrics() # based on calibrated logits + uncalibrated_loss_metrics = LossMetrics() # based on uncalibrated logits + + loader = train_loader if epoch_type == utils.Epoch.TRAIN else valid_loader + loader_iter = iter(loader) + + next_batch_cpu = next(loader_iter) + next_batch = next_batch_cpu.copy_to(self._device, self._dtype, non_blocking=is_cuda) + + pbar = tqdm(range(len(loader)), mininterval=60) + for n in pbar: + # forward and backward pass on batch, which is the last iteration's prefetched "next_batch" + batch_cpu = next_batch_cpu + batch = next_batch + + # Optimization: Asynchronously send the next batch to the device while the model does work + next_batch_cpu = next(loader_iter) + next_batch = next_batch_cpu.copy_to(self._device, self._dtype, non_blocking=is_cuda) + + logits, precalibrated_logits, features = self.forward(batch) + + # one-hot prediction of sources + if num_sources > 1: + # gradient reversal means parameters before the features try to maximize source prediction loss, i.e. features + # try to forget the source, while parameters after the features try to minimize it, i.e. they try + # to achieve the adversarial task of distinguishing sources + source_prediction_logits = source_classifier.forward(source_gradient_reversal(features)) + source_prediction_probs = torch.nn.functional.softmax(source_prediction_logits, dim=-1) + source_prediction_targets = torch.nn.functional.one_hot(batch.get_sources().long(), num_sources) + source_prediction_losses = torch.sum(torch.square(source_prediction_probs - source_prediction_targets), dim=-1) + + # TODO: always by count? + source_prediction_weights = calculate_batch_source_weights(batch_cpu, dataset, by_count=is_calibration_epoch) + source_prediction_weights = source_prediction_weights.to(device=self._device, dtype=self._dtype, non_blocking=True) + else: + source_prediction_losses = torch.zeros_like(logits, device=self._device) + source_prediction_weights = torch.zeros_like(logits, device=self._device) + + # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss + weights = calculate_batch_weights(batch_cpu, dataset, by_count=True) + weights = weights.to(device=self._device, dtype=self._dtype, non_blocking=True) + + labels = batch.get_training_labels() + uncalibrated_cross_entropies = bce(precalibrated_logits, labels) + calibrated_cross_entropies = bce(logits, labels) + labeled_losses = batch.get_is_labeled_mask() * (uncalibrated_cross_entropies + calibrated_cross_entropies) / 2 + + # unlabeled loss: entropy regularization. We use the uncalibrated logits because otherwise entropy + # regularization simply biases calibration to be overconfident. + probabilities = torch.sigmoid(precalibrated_logits) + entropies = torch.nn.functional.binary_cross_entropy_with_logits(precalibrated_logits, probabilities, reduction='none') + unlabeled_losses = (1 - batch.get_is_labeled_mask()) * entropies + + # these losses include weights and take labeled vs unlabeled into account + losses = (labeled_losses + unlabeled_losses) * weights + (source_prediction_losses * source_prediction_weights) + loss = torch.sum(losses) + + # at this point, losses, weights are on GPU (if available), while metrics are on CPU + # if we have done things right, this is okay and record_losses handles GPU <--> CPU efficiently + loss_metrics.record_losses(calibrated_cross_entropies.detach(), batch, weights * batch.get_is_labeled_mask()) + uncalibrated_loss_metrics.record_losses(uncalibrated_cross_entropies.detach(), batch, weights * batch.get_is_labeled_mask()) + uncalibrated_loss_metrics.record_losses(entropies.detach(), batch, weights * (1 - batch.get_is_labeled_mask())) + source_prediction_loss_metrics.record_losses(source_prediction_losses.detach(), batch, source_prediction_weights) + + # calibration epochs freeze the model up to calibration, so I wonder if a purely unlabeled batch + # would cause lack of gradient problems. . . + if epoch_type == utils.Epoch.TRAIN: + utils.backpropagate(train_optimizer, loss) + + # done with one epoch type -- training or validation -- for this epoch + loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer) + source_prediction_loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="source prediction") + uncalibrated_loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="uncalibrated") + if epoch_type == utils.Epoch.TRAIN: + train_scheduler.step(loss_metrics.get_labeled_loss()) + + print(f"Labeled loss for {epoch_type.name} epoch {epoch}: {loss_metrics.get_labeled_loss():.3f}") + print(f"Unlabeled loss for {epoch_type.name} epoch {epoch}: {uncalibrated_loss_metrics.get_unlabeled_loss():.3f}") + if num_sources > 1: + print(f"Adversarial source prediction loss on labeled data for {epoch_type.name} epoch {epoch}: {source_prediction_loss_metrics.get_labeled_loss():.3f}") + print(f"Adversarial source prediction loss on unlabeled data for {epoch_type.name} epoch {epoch}: {source_prediction_loss_metrics.get_unlabeled_loss():.3f}") + # done with training and validation for this epoch + print(f"End of epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}, time elapsed(s): {time.time() - start_of_epoch:.2f}") + is_last = (epoch == last_epoch) + if (epochs_per_evaluation is not None and epoch % epochs_per_evaluation == 0) or is_last: + print(f"performing evaluation on epoch {epoch}") + self.evaluate_model(epoch, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=False, report_worst=False) + if is_last: + # collect data in order to do final calibration + print("collecting data for final calibration") + evaluation_metrics, _ = self.collect_evaluation_data(dataset, train_loader, valid_loader, report_worst=False) + + logit_adjustments_by_var_type_and_count_bin = evaluation_metrics.metrics[Epoch.VALID].calculate_logit_adjustments(use_harmonic_mean=False) + print("here are the logit adjustments:") + for var_type_idx, var_type in enumerate(Variation): + adjustments_by_count_bin = logit_adjustments_by_var_type_and_count_bin[var_type] + max_bin_idx = len(adjustments_by_count_bin) - 1 + max_count = multiple_of_three_bin_index_to_count(max_bin_idx) + adjustments_by_count = torch.zeros(max_count + 1) + for count in range(max_count + 1): + bin_idx = multiple_of_three_bin_index(count) + # negative sign because these are subtractive adjustments + adjustments_by_count[count] = -adjustments_by_count_bin[bin_idx] + print(f"for variant type {var_type.name} the adjustments are ") + print(adjustments_by_count.tolist()) + self.calibration[var_type_idx].set_adjustments(adjustments_by_count) + + # consider this an extra post-postprocessing/final calibration epoch, hence epoch+1 + print("doing one final evaluation after the last logit adjustment") + self.evaluate_model(epoch + 1, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=True, report_worst=True) + + # note that we have not learned the AF spectrum yet + # done with training + + def evaluate_model_after_training(self, dataset: ArtifactDataset, batch_size, num_workers, summary_writer: SummaryWriter): + train_loader = dataset.make_data_loader(dataset.all_but_the_last_fold(), batch_size, self._device.type == 'cuda', num_workers) + valid_loader = dataset.make_data_loader(dataset.last_fold_only(), batch_size, self._device.type == 'cuda', num_workers) + self.evaluate_model(None, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=True, report_worst=True) + + @torch.inference_mode() + def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_loader, report_worst: bool): + # the keys are tuples of (true label -- 1 for variant, 0 for artifact; rounded alt count) + worst_offenders_by_truth_and_alt_count = defaultdict(lambda: PriorityQueue(WORST_OFFENDERS_QUEUE_SIZE)) + + evaluation_metrics = EvaluationMetrics() + epoch_types = [Epoch.TRAIN, Epoch.VALID] + for epoch_type in epoch_types: + assert epoch_type == Epoch.TRAIN or epoch_type == Epoch.VALID # not doing TEST here + loader = train_loader if epoch_type == Epoch.TRAIN else valid_loader + pbar = tqdm(enumerate(loader), mininterval=60) + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(self._device, self._dtype, non_blocking=self._device.type == 'cuda') + + # these are the same weights used in training + # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss + weights = calculate_batch_weights(batch_cpu, dataset, by_count=True) + weights = weights.to(dtype=self._dtype) # not sent to GPU! + + logits, _, _ = self.forward(batch) + # logits are calculated on the GPU (when available), so we must detach AND send back to CPU (if applicable) + pred = logits.detach().cpu() + + # note that for metrics we use batch_cpu + labels = batch_cpu.get_training_labels() + correct = ((pred > 0) == (labels > 0.5)).tolist() + + for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, variant, weight in zip( + batch_cpu.get_variant_types().tolist(), pred.tolist(), labels.tolist(), batch_cpu.get_is_labeled_mask().tolist(), correct, + batch_cpu.get_alt_counts().tolist(), batch_cpu.get_variants(), weights.tolist()): + if is_labeled < 0.5: # we only evaluate labeled data + continue + evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call, alt_count, weight) + if report_worst and not correct_call: + rounded_count = round_up_to_nearest_three(alt_count) + label_name = Label.ARTIFACT.name if label > 0.5 else Label.VARIANT.name + confidence = abs(predicted_logit) + + # the 0th aka highest priority element in the queue is the one with the lowest confidence + pqueue = worst_offenders_by_truth_and_alt_count[(label_name, rounded_count)] + + # clear space if this confidence is more egregious + if pqueue.full() and pqueue.queue[0][0] < confidence: + pqueue.get() # discards the least confident bad call + + if not pqueue.full(): # if space was cleared or if it wasn't full already + pqueue.put((confidence, str(variant.contig) + ":" + str( + variant.position) + ':' + variant.ref + "->" + variant.alt)) + # done with this epoch type + # done collecting data + return evaluation_metrics, worst_offenders_by_truth_and_alt_count + + @torch.inference_mode() + def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, valid_loader, summary_writer: SummaryWriter, + collect_embeddings: bool = False, report_worst: bool = False): + + # self.freeze_all() + evaluation_metrics, worst_offenders_by_truth_and_alt_count = self.collect_evaluation_data(dataset, train_loader, valid_loader, report_worst) + evaluation_metrics.make_plots(summary_writer, epoch=epoch) + + if report_worst: + for (true_label, rounded_count), pqueue in worst_offenders_by_truth_and_alt_count.items(): + tag = "True label: " + true_label + ", rounded alt count: " + str(rounded_count) + + lines = [] + while not pqueue.empty(): # this goes from least to most egregious, FYI + confidence, var_string = pqueue.get() + lines.append(f"{var_string} ({confidence:.2f})") + + summary_writer.add_text(tag, "\n".join(lines), global_step=epoch) + + if collect_embeddings: + embedding_metrics = EmbeddingMetrics() + + # now go over just the validation data and generate feature vectors / metadata for tensorboard projectors (UMAP) + pbar = tqdm(enumerate(valid_loader), mininterval=60) + + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(self._device, self._dtype, non_blocking=self._device.type == 'cuda') + logits, _, _ = self.forward(batch) + pred = logits.detach().cpu() + labels = batch_cpu.get_training_labels() + correct = ((pred > 0) == (labels > 0.5)).tolist() + + label_strings = [("artifact" if label > 0.5 else "non-artifact") if is_labeled > 0.5 else "unlabeled" + for (label, is_labeled) in zip(labels.tolist(), batch_cpu.get_is_labeled_mask().tolist())] + + correct_strings = [str(correctness) if is_labeled > 0.5 else "-1" + for (correctness, is_labeled) in zip(correct, batch_cpu.get_is_labeled_mask().tolist())] + + for (metrics, embedding) in [(embedding_metrics, batch_cpu.get_representations_2d().detach())]: + metrics.label_metadata.extend(label_strings) + metrics.correct_metadata.extend(correct_strings) + metrics.type_metadata.extend([Variation(idx).name for idx in batch_cpu.get_variant_types().tolist()]) + metrics.truncated_count_metadata.extend([str(round_up_to_nearest_three(min(MAX_COUNT, alt_count))) for alt_count in batch_cpu.get_alt_counts().tolist()]) + metrics.representations.append(embedding) + embedding_metrics.output_to_summary_writer(summary_writer, epoch=epoch) + # done collecting data + + def make_dict_for_saving(self, artifact_log_priors, artifact_spectra, prefix: str = "artifact"): + return {(prefix + constants.STATE_DICT_NAME): self.state_dict(), + (prefix + constants.NUM_BASE_FEATURES_NAME): self.num_base_features, + (prefix + constants.NUM_REF_ALT_FEATURES_NAME): self.num_ref_alt_features, + (prefix + constants.HYPERPARAMS_NAME): self.params, + (prefix + constants.ARTIFACT_LOG_PRIORS_NAME): artifact_log_priors, + (prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME): artifact_spectra.state_dict()} + + def save(self, path, artifact_log_priors, artifact_spectra, prefix: str = "artifact"): + torch.save(self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix), path) + + def save_with_base_model(self, base_model: BaseModel, path, artifact_log_priors, artifact_spectra): + artifact_dict = self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix="artifact") + base_dict = base_model.make_dict_for_saving(prefix="base") + torch.save({**artifact_dict, **base_dict}, path) + + +def artifact_model_from_saved_dict(saved, prefix: str = "artifact"): + model_params = saved[prefix + constants.HYPERPARAMS_NAME] + num_base_features = saved[prefix + constants.NUM_BASE_FEATURES_NAME] + num_ref_alt_features = saved[prefix + constants.NUM_REF_ALT_FEATURES_NAME] + model = ArtifactModel(model_params, num_base_features, num_ref_alt_features) + model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME]) + + artifact_log_priors = saved[prefix + constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None + artifact_spectra_state_dict = saved[prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] # possibly None + return model, artifact_log_priors, artifact_spectra_state_dict + + +# log artifact priors and artifact spectra may be None +def load_artifact_model(path, device, prefix: str = "artifact") -> ArtifactModel: + saved = torch.load(path, map_location=device) + return artifact_model_from_saved_dict(saved, prefix) + + +def load_base_model_and_artifact_model(path, device) -> ArtifactModel: + saved = torch.load(path, map_location=device) + base_model = base_model_from_saved_dict(saved, prefix="base") + artifact_model, artifact_log_priors, artifact_spectra = artifact_model_from_saved_dict(saved, prefix="artifact") + return base_model, artifact_model, artifact_log_priors, artifact_spectra + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/artifact_spectra.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/artifact_spectra.py new file mode 100644 index 00000000000..c59f66b377b --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/artifact_spectra.py @@ -0,0 +1,131 @@ +import math + +import torch +from permutect import utils +from torch import nn, IntTensor + +from permutect.metrics.plotting import simple_plot +from permutect.utils import beta_binomial, Variation + + +class ArtifactSpectra(nn.Module): + """ + This model takes in 1D tensors (batch size, ) of alt counts and depths and a 2D tensor (batch size, len(Variation)) + of one-hot-encoded variant types and computes the log likelihoods + log P(alt count | depth, variant type, spectrum parameters). + + The probability P(alt count | depth) is a K-component beta binomial mixture model where each component is a beta binomial + P_k(a|d) = integral{Beta(f|alpha, beta) * Binom(a|d, f) df}. + + This integral is exact and is implemented in utils.beta_binomial() + + Importantly, the beta shape parameter is *not* learnable. We fix it at a high constant value in order to force the spectrum + to fall off very rapidly after its peak allele fraction. Otherwise, the unrealistically long tail gives artificially + high likelihoods for artifacts at high allele fractions. + """ + + def __init__(self, num_components: int): + super(ArtifactSpectra, self).__init__() + self.beta = 100 # not a learnable parameter! + self.K = num_components + self.V = len(Variation) + + self.weights0_pre_softmax_vk = torch.nn.Parameter(torch.ones(self.V, self.K)) + + # for each component and variant type: + # weight_pre_softmax = weight0_pre_softmax + gamma * sigmoid(depth * kappa) + self.gamma_vk = torch.nn.Parameter(0.1 * torch.rand(self.V, self.K)) + self.kappa_vk = torch.nn.Parameter(0.02 * torch.ones(self.V, self.K)) + + # initialize evenly spaced alphas from 1 to 7 for each variant type + self.alpha0_pre_exp_vk = torch.nn.Parameter(torch.log(1 + 7 * (torch.arange(self.K) / self.K)).repeat(self.V, 1)) + + self.eta_pre_exp_vk = torch.nn.Parameter(torch.ones(self.V, self.K)) + + self.delta_pre_exp_vk = torch.nn.Parameter(torch.log(torch.ones(self.V, self.K)/50)) + + # for each component and variant type: + # alpha = exp(alpha0_pre_exp - exp(eta_pre_exp)*sigmoid(depth * exp(delta_pre_exp))) + + ''' + here x is a 2D tensor, 1st dimension batch, 2nd dimension being features that determine which Beta mixture to use + n and k are 1D tensors, the only dimension being batch. + ''' + def forward(self, variant_types_b: torch.IntTensor, depths_b, alt_counts_b): + # indexing convention: b is batch, v is variant type, k is cluster component + alt_counts_bk = alt_counts_b[:, None] + depths_bk = depths_b[:, None] + + eta_vk = torch.exp(self.eta_pre_exp_vk) + delta_vk = torch.exp(self.delta_pre_exp_vk) + + var_types_b = variant_types_b.long() + alpha0_pre_exp_bk = self.alpha0_pre_exp_vk[var_types_b, :] + delta_bk = delta_vk[var_types_b, :] + eta_bk = eta_vk[var_types_b, :] + alpha_bk = torch.exp(alpha0_pre_exp_bk - eta_bk * torch.sigmoid(depths_bk * delta_bk)) + beta_bk = self.beta * torch.ones_like(alpha_bk) + beta_binomial_likelihoods_bk = beta_binomial(depths_bk, alt_counts_bk, alpha_bk, beta_bk) + + if alpha_bk.isnan().any(): + print("NaN found in alpha_bk") + assert 1 < 0, "FAIL" + + weights0_pre_softmax_bk = self.weights0_pre_softmax_vk[var_types_b, :] + gamma_bk = self.gamma_vk[var_types_b, :] + kappa_bk = self.kappa_vk[var_types_b, :] + weights_pre_softmax_bk = weights0_pre_softmax_bk + gamma_bk * torch.sigmoid(depths_bk * kappa_bk) + log_weights_bk = torch.log_softmax(weights_pre_softmax_bk, dim=-1) # softmax over component dimension + + weighted_likelihoods_bk = log_weights_bk + beta_binomial_likelihoods_bk + result_b = torch.logsumexp(weighted_likelihoods_bk, dim=-1, keepdim=False) + return result_b + + # TODO: utter code duplication with somatic spectrum + def fit(self, num_epochs, types_b: IntTensor, depths_1d_tensor, alt_counts_1d_tensor, batch_size=64): + optimizer = torch.optim.Adam(self.parameters()) + num_batches = math.ceil(len(alt_counts_1d_tensor) / batch_size) + + for epoch in range(num_epochs): + for batch in range(num_batches): + batch_start = batch * batch_size + batch_end = min(batch_start + batch_size, len(alt_counts_1d_tensor)) + batch_slice = slice(batch_start, batch_end) + loss = -torch.mean(self.forward(types_b[batch_slice], depths_1d_tensor[batch_slice], alt_counts_1d_tensor[batch_slice])) + utils.backpropagate(optimizer, loss) + + ''' + get raw data for a spectrum plot of probability density vs allele fraction for a particular variant type + ''' + def spectrum_density_vs_fraction(self, variant_type: Variation, depth: int): + fractions_f = torch.arange(0.01, 0.99, 0.001) # 1D tensor + + weights0_pre_softmax_k = self.weights0_pre_softmax_vk[variant_type] + gamma_k = self.gamma_vk[variant_type] + kappa_k = self.kappa_vk[variant_type] + alpha0_pre_exp_k = self.alpha0_pre_exp_vk[variant_type] + eta_k = torch.exp(self.eta_pre_exp_vk[variant_type]) + delta_k = torch.exp(self.delta_pre_exp_vk[variant_type]) + + alpha_k = torch.exp(alpha0_pre_exp_k - eta_k * torch.sigmoid(depth * delta_k)) + weights_pre_softmax_k = weights0_pre_softmax_k + gamma_k * torch.sigmoid(depth * kappa_k) + + log_weights_k = torch.log_softmax(weights_pre_softmax_k, dim=0) + beta_k = self.beta * torch.ones_like(alpha_k) + + # distribution done on CPU + dist = torch.distributions.beta.Beta(alpha_k.cpu(), beta_k.cpu()) + log_densities_fk = dist.log_prob(fractions_f.unsqueeze(dim=-1)) + log_weights_fk = log_weights_k.unsqueeze(dim=0).cpu() + + log_weighted_densities_fk = log_weights_fk + log_densities_fk + densities_f = torch.exp(torch.logsumexp(log_weighted_densities_fk, dim=1, keepdim=False)) + + return fractions_f, densities_f + + ''' + here x is a 1D tensor, a single datum/row of the 2D tensors as above + ''' + def plot_spectrum(self, variant_type: Variation, title, depth: int): + fractions, densities = self.spectrum_density_vs_fraction(variant_type, depth) + return simple_plot([(fractions.numpy(), densities.numpy(), " ")], "AF", "density", title) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/base_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/base_model.py new file mode 100644 index 00000000000..556d6c9dbe9 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/base_model.py @@ -0,0 +1,618 @@ +import math +from abc import ABC, abstractmethod +from enum import Enum +from itertools import chain +import time +from typing import List + +import psutil +import torch +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from torch.nn.parameter import Parameter +from tqdm.autonotebook import trange, tqdm + +from permutect import utils, constants +from permutect.architecture.dna_sequence_convolution import DNASequenceConvolution +from permutect.architecture.gated_mlp import GatedMLP, GatedRefAltMLP +from permutect.architecture.gradient_reversal.module import GradientReversal +from permutect.architecture.mlp import MLP +from permutect.data.base_datum import BaseBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT +from permutect.data.base_dataset import BaseDataset, ALL_COUNTS_SENTINEL +from permutect.metrics.evaluation_metrics import LossMetrics, EmbeddingMetrics, round_up_to_nearest_three, MAX_COUNT +from permutect.parameters import BaseModelParameters, TrainingParameters + + +# group rows into consecutive chunks to yield a 3D tensor, average over dim=1 to get +# 2D tensor of sums within each chunk +from permutect.utils import Variation, Label + + +def sums_over_chunks(tensor2d: torch.Tensor, chunk_size: int): + assert len(tensor2d) % chunk_size == 0 + return torch.sum(tensor2d.reshape([len(tensor2d) // chunk_size, chunk_size, -1]), dim=1) + + +# note: this works for both BaseBatch/BaseDataset AND ArtifactBatch/ArtifactDataset +# if by_count is True, each count is weighted separately for balanced loss within that count +def calculate_batch_weights(batch, dataset, by_count: bool): + # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss + # For batch index n, we want weight[n] = dataset.weights[alt_counts[n], labels[n], variant_types[n]] + counts = batch.get_alt_counts() + labels = batch.get_labels() + variant_types = batch.get_variant_types() + + return utils.index_3d_array(dataset.weights, counts, labels, variant_types) if by_count else \ + utils.index_2d_array(dataset.weights[ALL_COUNTS_SENTINEL], labels, variant_types) + + +# note: this works for both BaseBatch/BaseDataset AND ArtifactBatch/ArtifactDataset +# if by_count is True, each count is weighted separately for balanced loss within that count +def calculate_batch_source_weights(batch, dataset, by_count: bool): + # For batch index n, we want weight[n] = dataset.source_weights[alt_counts[n], sources[n], variant_types[n]] + counts = batch.get_alt_counts() + sources = batch.get_sources() + variant_types = batch.get_variant_types() + + return utils.index_3d_array(dataset.source_weights, counts, sources, variant_types) if by_count else \ + utils.index_2d_array(dataset.source_weights[ALL_COUNTS_SENTINEL], sources, variant_types) + + +class LearningMethod(Enum): + # train the embedding by minimizing cross-entropy loss of binary predictor on labeled data + SUPERVISED = "SUPERVISED" + + # same but use entropy regularization loss on unlabeled data + SEMISUPERVISED = "SEMISUPERVISED" + + # TODO: IMPLEMENT THIS + # optimize a clustering model with center triplet loss + SUPERVISED_CLUSTERING = "SUPERVISED_CLUSTERING" + + # TODO: IMPLEMENT THIS + # modify data via a finite set of affine transformations and train the embedding to recognize which was applied + AFFINE_TRANSFORMATION = "AFFINE" + + # modify data via a finite set of affine transformations and train the embedding to recognize which was applied + MASK_PREDICTION = "MASK_PREDICTION" + + AUTOENCODER = "AUTOENCODER" + + DEEPSAD = "DEEPSAD" + + MARS = "MARS" + + +def make_gated_ref_alt_mlp_encoder(input_dimension: int, params: BaseModelParameters): + return GatedRefAltMLP(d_model=input_dimension, d_ffn=params.self_attention_hidden_dimension, num_blocks=params.num_self_attention_layers) + + +class BaseModel(torch.nn.Module): + """ + DeepSets framework for reads and variant info. We embed each read and concatenate the mean ref read + embedding, mean alt read embedding, and variant info embedding, then apply an aggregation function to + this concatenation to obtain an embedding / representation of the read set for downstream use such as + variant filtering and clustering. + + hidden_read_layers: dimensions of layers for embedding reads, excluding input dimension, which is the + size of each read's 1D tensor + + hidden_info_layers: dimensions of layers for embedding variant info, excluding input dimension, which is the + size of variant info 1D tensor + + aggregation_layers: dimensions of layers for aggregation, excluding its input which is determined by the + read and info embeddings. + + output_layers: dimensions of layers after aggregation, excluding the output dimension, + which is 1 for a single logit representing artifact/non-artifact. This is not part of the aggregation layers + because we have different output layers for each variant type. + """ + + def __init__(self, params: BaseModelParameters, num_read_features: int, num_info_features: int, ref_sequence_length: int, device=utils.gpu_if_available()): + super(BaseModel, self).__init__() + + self._device = device + self._dtype = DEFAULT_GPU_FLOAT if device != torch.device("cpu") else DEFAULT_CPU_FLOAT + self._ref_sequence_length = ref_sequence_length + self._params = params + + # embeddings of reads, info, and reference sequence prior to the transformer layers + self.read_embedding = MLP([num_read_features] + params.read_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + self.info_embedding = MLP([num_info_features] + params.info_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + self.ref_seq_cnn = DNASequenceConvolution(params.ref_seq_layer_strings, ref_sequence_length) + + embedding_dim = self.read_embedding.output_dimension() + self.info_embedding.output_dimension() + self.ref_seq_cnn.output_dimension() + + self.ref_alt_reads_encoder = make_gated_ref_alt_mlp_encoder(embedding_dim, params) + + # after encoding alt reads (along with info and ref seq embeddings and with self-attention to ref reads) + # pass through another MLP + self.aggregation = MLP([embedding_dim] + params.aggregation_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + + self.to(device=self._device, dtype=self._dtype) + + def output_dimension(self) -> int: + return self.aggregation.output_dimension() + + def ref_alt_seq_embedding_dimension(self) -> int: + return self.ref_seq_cnn.output_dimension() + + def ref_sequence_length(self) -> int: + return self._ref_sequence_length + + def set_epoch_type(self, epoch_type: utils.Epoch): + if epoch_type == utils.Epoch.TRAIN: + self.train(True) + utils.unfreeze(self.parameters()) + else: + self.train(False) + utils.freeze(self.parameters()) + + # I really don't like the forward method of torch.nn.Module with its implicit calling that PyCharm doesn't recognize + def forward(self, batch: BaseBatch): + pass + + # here 'v' means "variant index within a batch", 'r' means "read index within a variant or the batch", 'e' means "index within an embedding" + # so, for example, "re" means a 2D tensor with all reads in the batch stacked and "vre" means a 3D tensor indexed + # first by variant within the batch, then the read + def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) -> torch.Tensor: + ref_counts, alt_counts = batch.get_ref_counts(), batch.get_alt_counts() + total_ref, total_alt = torch.sum(ref_counts).item(), torch.sum(alt_counts).item() + + read_embeddings_re = self.read_embedding.forward(batch.get_reads_2d().to(dtype=self._dtype)) + info_embeddings_ve = self.info_embedding.forward(batch.get_info_2d().to(dtype=self._dtype)) + ref_seq_embeddings_ve = self.ref_seq_cnn(batch.get_ref_sequences_2d().to(dtype=self._dtype)) + info_and_seq_ve = torch.hstack((info_embeddings_ve, ref_seq_embeddings_ve)) + info_and_seq_re = torch.vstack((torch.repeat_interleave(info_and_seq_ve, repeats=ref_counts, dim=0), + torch.repeat_interleave(info_and_seq_ve, repeats=alt_counts, dim=0))) + reads_info_seq_re = torch.hstack((read_embeddings_re, info_and_seq_re)) + + # TODO: might be a bug if every datum in batch has zero ref reads? + ref_reads_info_seq_re = reads_info_seq_re[:total_ref] + alt_reads_info_seq_re = reads_info_seq_re[total_ref:] + + # TODO: make sure it handles ref count = 0 case + transformed_ref_re, transformed_alt_re = self.ref_alt_reads_encoder.forward(ref_reads_info_seq_re, alt_reads_info_seq_re, ref_counts, alt_counts) + + alt_weights_r = 1 + weight_range * (1 - 2 * torch.rand(total_alt, device=self._device, dtype=self._dtype)) + + # normalize so read weights within each variant sum to 1 + alt_wt_sums_v = utils.sums_over_rows(alt_weights_r, alt_counts) + normalized_alt_weights_r = alt_weights_r / torch.repeat_interleave(alt_wt_sums_v, repeats=alt_counts, dim=0) + + alt_means_ve = utils.sums_over_rows(transformed_alt_re * normalized_alt_weights_r[:,None], alt_counts) + + result_ve = self.aggregation.forward(alt_means_ve) + + return result_ve, ref_seq_embeddings_ve # ref seq embeddings are useful later + + def make_dict_for_saving(self, prefix: str = ""): + return {(prefix + constants.STATE_DICT_NAME): self.state_dict(), + (prefix + constants.HYPERPARAMS_NAME): self._params, + (prefix + constants.NUM_READ_FEATURES_NAME): self.read_embedding.input_dimension(), + (prefix + constants.NUM_INFO_FEATURES_NAME): self.info_embedding.input_dimension(), + (prefix + constants.REF_SEQUENCE_LENGTH_NAME): self.ref_sequence_length()} + + def save(self, path): + torch.save(self.make_dict_for_saving(), path) + + +def base_model_from_saved_dict(saved, prefix: str = "", device: torch.device = utils.gpu_if_available()): + hyperparams = saved[prefix + constants.HYPERPARAMS_NAME] + num_read_features = saved[prefix + constants.NUM_READ_FEATURES_NAME] + num_info_features = saved[prefix + constants.NUM_INFO_FEATURES_NAME] + ref_sequence_length = saved[prefix + constants.REF_SEQUENCE_LENGTH_NAME] + + model = BaseModel(hyperparams, num_read_features=num_read_features, num_info_features=num_info_features, + ref_sequence_length=ref_sequence_length, device=device) + model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME]) + + # in case the state dict had the wrong dtype for the device we're on now eg base model was pretrained on GPU + # and we're now on CPU + model.to(model._dtype) + + return model + + +def load_base_model(path, prefix: str = "", device: torch.device = utils.gpu_if_available()) -> BaseModel: + saved = torch.load(path, map_location=device) + return base_model_from_saved_dict(saved, prefix, device) + + +# outputs a 1D tensor of losses over the batch. We assume it needs the representations of the batch data from the base +# model. We nonetheless also use the model as an input because there are some learning strategies that involve +# computing representations of a modified batch. +class BaseModelLearningStrategy(ABC): + @abstractmethod + def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations: torch.Tensor): + pass + + +class BaseModelSemiSupervisedLoss(torch.nn.Module, BaseModelLearningStrategy): + def __init__(self, input_dim: int, hidden_top_layers: List[int], params: BaseModelParameters): + super(BaseModelSemiSupervisedLoss, self).__init__() + + self.bce = torch.nn.BCEWithLogitsLoss(reduction='none') # no reduction because we may want to first multiply by weights for unbalanced data + + # go from base model output representation to artifact logit for supervised loss + self.logit_predictor = MLP([input_dim] + hidden_top_layers + [1], batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + + def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations: torch.Tensor): + logits = self.logit_predictor.forward(base_model_representations).reshape((base_batch.size())) + labels = base_batch.get_training_labels() + + # base batch always has labels, but for unlabeled elements these labels are meaningless and is_labeled_mask is zero + cross_entropies = self.bce(logits, labels) + probabilities = torch.sigmoid(logits) + entropies = self.bce(logits, probabilities) + + return base_batch.get_is_labeled_mask() * cross_entropies + (1 - base_batch.get_is_labeled_mask()) * entropies + + # I don't like implicit forward!! + def forward(self): + pass + + +def permute_columns_independently(mat: torch.Tensor): + assert mat.dim() == 2 + num_rows, num_cols = mat.size() + weights = torch.ones(num_rows) + + result = torch.clone(mat) + for col in range(num_cols): + idx = torch.multinomial(weights, num_rows, replacement=True) + result[:, col] = result[:, col][idx] + return result + + +# randomly choose read features to "mask" -- where a masked feature is permuted randomly over all the reads in the batch. +# this essentially means drawing masked features from the empirical marginal distribution +# the pretext self-supervision task is, for each datum, to predict which features were masked +# note that this basically means destroy correlations for a random selection of features +class BaseModelMaskPredictionLoss(torch.nn.Module, BaseModelLearningStrategy): + def __init__(self, num_read_features: int, base_model_output_dim: int, hidden_top_layers: List[int], params: BaseModelParameters): + super(BaseModelMaskPredictionLoss, self).__init__() + + self.num_read_features = num_read_features + + self.bce = torch.nn.BCEWithLogitsLoss(reduction='none') # no reduction because we may want to first multiply by weights for unbalanced data + + # go from base model output representation to artifact logit for supervised loss + self.mask_predictor = MLP([base_model_output_dim] + hidden_top_layers + [num_read_features], batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) + + def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations): + # TODO: this is broken now that batches have mixed counts + '''ref_count, alt_count = base_batch.ref_count, base_batch.alt_count + total_ref, total_alt = ref_count * base_batch.size(), alt_count * base_batch.size() + + alt_reads_2d = base_batch.get_reads_2d()[total_ref:] + permuted_reads = permute_columns_independently(base_batch.get_reads_2d()) + permuted_alt_reads = permuted_reads[:total_alt] + + datum_mask = torch.bernoulli(0.1 * torch.ones(base_batch.size(), self.num_read_features)) + + # each read within a datum gets the same mask + reads_mask = torch.repeat_interleave(datum_mask, repeats=alt_count, dim=0) + + original_reads_2d = base_batch.reads_2d + modified_alt_reads = alt_reads_2d * (1 - reads_mask) + permuted_alt_reads * reads_mask + base_batch.reads_2d = torch.vstack((original_reads_2d[:total_ref], modified_alt_reads)) + + # TODO: is there any reason to fix the batch with base_batch.reads_2d = original_reads_2d? + + # shape is batch size x num read features, each entry being a logit for "was this feature masked in this datum?" + mask_prediction_logits = self.mask_predictor.forward(base_model_representations) + + # by batch index and feature + losses_bf = self.bce(mask_prediction_logits, datum_mask) + return torch.mean(losses_bf, dim=1) # average over read features + ''' + pass + + # I don't like implicit forward!! + def forward(self): + pass + + +# chamfer distance between two 3D tensors B x N1 x E and B x N2 x E, where B is the batch size, N1/2 are the number +# of items in the two sets, and E is the dimensionality of each item +# returns a 1D tensor of length B +def chamfer_distance(set1_bne, set2_bne): + diffs_bnne = torch.unsqueeze(set1_bne, dim=2) - torch.unsqueeze(set2_bne, dim=1) + l1_dists_bnn = torch.mean(torch.abs(diffs_bnne), dim=-1) + + chamfer_dists12_bn = torch.min(l1_dists_bnn, dim=-2).values + chamfer_dists21_bn = torch.min(l1_dists_bnn, dim=-1).values + symmetric_chamfer_b = torch.mean(chamfer_dists12_bn, dim=-1) + torch.mean(chamfer_dists21_bn, dim=-1) + return symmetric_chamfer_b + + +# self-supervision approach where we use the base model embedding to regenerate the set and use Chamfer distance as the +# reconstruction error. We regenerate the set via the Transformer Set Prediction Network approach of Kosiorek et al -- seed a set +# of N reads by concatenated the embedding with N random vectors, then map it so the final reconstructed set with transformers. +class BaseModelAutoencoderLoss(torch.nn.Module, BaseModelLearningStrategy): + def __init__(self, read_dim: int, hidden_top_layers: List[int], params: BaseModelParameters): + super(BaseModelAutoencoderLoss, self).__init__() + self.base_model_output_dimension = params.output_dimension() + + # TODO: explore making random seed dimension different from the base model embedding dimension + self.random_seed_dimension = self.base_model_output_dimension + self.transformer_dimension = self.base_model_output_dimension + self.random_seed_dimension + + # TODO: maybe also a parameter to scale the random vectors? + + # TODO: should these decoder params be the same as the base model encoder params? It seems reasonable. + + # TODO: this is broken -- use the ref_alt_encoder + #self.alt_decoder = make_gated_mlp_encoder(self.transformer_dimension, params) + #self.ref_decoder = make_gated_mlp_encoder(self.transformer_dimension, params) + + self.mapping_back_to_reads = MLP([self.transformer_dimension] + hidden_top_layers + [read_dim]) + + def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations): + # TODO: this is broken now that batches have mixed counts + '''var_count, alt_count, ref_count = base_batch.size(), base_batch.alt_count, base_batch.ref_count + + total_ref, total_alt = ref_count * var_count, alt_count * var_count + + representations_ve = base_model_representations + random_alt_seeds_vre = torch.randn(var_count, alt_count, self.random_seed_dimension) + random_ref_seeds_vre = torch.randn(var_count, ref_count, self.random_seed_dimension) if ref_count > 0 else None + alt_representations_vre = torch.unsqueeze(representations_ve, dim=1).expand(-1, alt_count, -1) # repeat over the dummy read index + ref_representations_vre = torch.unsqueeze(representations_ve, dim=1).expand(-1, ref_count, -1) + + alt_vre = torch.cat((alt_representations_vre, random_alt_seeds_vre), dim=-1) + ref_vre = torch.cat((ref_representations_vre, random_ref_seeds_vre), dim=-1) if ref_count > 0 else None + + # TODO: update these to reflect mixed-count batches. Gated MLPs now take inputs flattened over batch dimension + # TODO: and have an extra input of ref and alt read counts + decoded_alt_vre = self.alt_decoder.forward(alt_vre) + decoded_ref_vre = self.ref_decoder.forward(ref_vre) if ref_count > 0 else None + + decoded_alt_re = torch.reshape(decoded_alt_vre, (var_count * alt_count, -1)) + decoded_ref_re = torch.reshape(decoded_ref_vre, (var_count * ref_count, -1)) if ref_count > 0 else None + + # the raw read tensors are quantile normalized with Gaussian output + reconstructed_alt_vre = torch.reshape(self.mapping_back_to_reads(decoded_alt_re),(var_count, alt_count, -1)) + reconstructed_ref_vre = torch.reshape(self.mapping_back_to_reads(decoded_ref_re), (var_count, ref_count, -1)) if ref_count > 0 else None + + original_alt_vre = base_batch.get_reads_2d()[total_ref:].reshape(var_count, alt_count, -1) + original_ref_vre = base_batch.get_reads_2d()[:total_ref].reshape(var_count, ref_count, -1) if ref_count > 0 else None + + alt_chamfer_dist = chamfer_distance(original_alt_vre, reconstructed_alt_vre) + ref_chamfer_dist = chamfer_distance(original_ref_vre, reconstructed_ref_vre) if ref_count > 0 else 0 + return alt_chamfer_dist + ref_chamfer_dist + ''' + pass + + # I don't like implicit forward!! + def forward(self): + pass + + +class BaseModelDeepSADLoss(torch.nn.Module, BaseModelLearningStrategy): + def __init__(self, embedding_dim: int): + super(BaseModelDeepSADLoss, self).__init__() + self.embedding_dim = embedding_dim + self.normal_centroid = Parameter(torch.zeros(embedding_dim)) + + # normal embeddings should cluster near the origin and artifact embeddings should be far + def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations): + dist_squared = torch.square(torch.norm(base_model_representations - self.normal_centroid, dim=1)) + + # labels are 1 for artifact, 0 otherwise. We convert to +1 if normal, -1 if artifact + # DeepSAD assumes most unlabeled data are normal and so the unlabeled loss is identical to the normal loss, that is, + # squared Euclidean distance from the centroid + signs = (1 - 2 * base_batch.get_training_labels()) * base_batch.get_is_labeled_mask() + 1 * (1 - base_batch.get_is_labeled_mask()) + + # distance squared for normal and unlabeled, inverse distance squared for artifact + return dist_squared ** signs + + # I don't like implicit forward!! + def forward(self): + pass + + +class BaseModelMARSLoss(torch.nn.Module, BaseModelLearningStrategy): + def __init__(self, embedding_dim: int): + super(BaseModelMARSLoss, self).__init__() + self.embedding_dim = embedding_dim + # TODO: magic constants!!!!! + self.num_normal_clusters = 3 + self.num_artifact_clusters = 5 + + # weight of centroid-centroid loss vs embedding-centroid loss + self.tau = 0.2 + + # ce denotes indexing by cluster, then embedding dimension + self.centroids_ce = Parameter(torch.zeros((self.num_normal_clusters + self.num_artifact_clusters, embedding_dim))) + + # normal embeddings should cluster near the origin and artifact embeddings should be far + def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations): + embeddings_be = base_model_representations + seps_bce = torch.unsqueeze(embeddings_be, dim=1) - torch.unsqueeze(self.centroids_ce, dim=0) + dist_squared_bc = torch.square(torch.norm(seps_bce, dim=-1)) + + normal_dist_squared_b = torch.min(dist_squared_bc[:, :self.num_normal_clusters], dim=-1).values + artifact_dist_squared_b = torch.min(dist_squared_bc[:, self.num_normal_clusters:], dim=-1).values + min_dist_squared_b = torch.min(dist_squared_bc, dim=-1).values + + # closest centroid with correct label is labeled, otherwise just the closest centroid + labeled_losses_b = (base_batch.get_training_labels() * artifact_dist_squared_b + (1 - base_batch.get_training_labels()) * normal_dist_squared_b) + unlabeled_losses_b = min_dist_squared_b + embedding_centroid_losses_b = base_batch.get_is_labeled_mask() * labeled_losses_b + (1 - base_batch.get_is_labeled_mask()) * unlabeled_losses_b + + # average distance between centroids + centroid_seps_cce = torch.unsqueeze(self.centroids_ce, dim=0) - torch.unsqueeze(self.centroids_ce, dim=1) + centroid_dist_squared_cc = torch.square(torch.norm(centroid_seps_cce, dim=-1)) + centroid_centroid_loss = torch.mean(centroid_dist_squared_cc) + + # TODO: need to control arbitrarily negative loss achieved by making an unused centroid go far away from the others + # note that broadcasting the subtraction means the centroid-centroid loss is repeated once for each datum in the batch + return embedding_centroid_losses_b - self.tau * centroid_centroid_loss + + # I don't like implicit forward!! + def forward(self): + pass + + +# artifact model parameters are for simultaneously training an artifact model on top of the base model +# to measure quality, especially in unsupervised training when the loss metric isn't directly related to accuracy or cross-entropy +def learn_base_model(base_model: BaseModel, dataset: BaseDataset, learning_method: LearningMethod, training_params: TrainingParameters, + summary_writer: SummaryWriter, validation_fold: int = None): + print(f"Memory usage percent: {psutil.virtual_memory().percent:.1f}") + is_cuda = base_model._device.type == 'cuda' + print(f"Is CUDA available? {is_cuda}") + + for idx, variation_type in enumerate(utils.Variation): + print(f"For variation type {variation_type.name}, there are {int(dataset.totals[ALL_COUNTS_SENTINEL][Label.ARTIFACT][idx].item())} \ + artifacts, {int(dataset.totals[ALL_COUNTS_SENTINEL][Label.VARIANT][idx].item())} \ + non-artifacts, and {int(dataset.totals[ALL_COUNTS_SENTINEL][Label.UNLABELED][idx].item())} unlabeled data.") + + # TODO: use Python's match syntax, but this requires updating Python version in the docker + # TODO: hidden_top_layers are hard-coded! + if learning_method == LearningMethod.SUPERVISED or learning_method == LearningMethod.SEMISUPERVISED: + learning_strategy = BaseModelSemiSupervisedLoss(input_dim=base_model.output_dimension(), hidden_top_layers=[30,-1,-1,-1,10], params=base_model._params) + elif learning_method == LearningMethod.MASK_PREDICTION: + learning_strategy = BaseModelMaskPredictionLoss(num_read_features=dataset.num_read_features, + base_model_output_dim=base_model.output_dimension(), hidden_top_layers=[10,10,10], params=base_model._params) + elif learning_method == LearningMethod.AUTOENCODER: + learning_strategy = BaseModelAutoencoderLoss(read_dim=dataset.num_read_features, hidden_top_layers=[20,20,20], params=base_model._params) + elif learning_method == LearningMethod.DEEPSAD: + learning_strategy = BaseModelDeepSADLoss(embedding_dim=base_model.output_dimension()) + elif learning_method == LearningMethod.MARS: + learning_strategy = BaseModelMARSLoss(embedding_dim=base_model.output_dimension()) + else: + raise Exception("not implemented yet") + learning_strategy.to(device=base_model._device, dtype=base_model._dtype) + + # adversarial loss to learn features that forget the alt count + alt_count_gradient_reversal = GradientReversal(alpha=0.01) #initialize as barely active + alt_count_predictor = MLP([base_model.output_dimension()] + [30, -1, -1, -1, 1]).to(device=base_model._device, dtype=base_model._dtype) + alt_count_loss_func = torch.nn.MSELoss(reduction='none') + alt_count_adversarial_metrics = LossMetrics() + + # TODO: fused = is_cuda? + train_optimizer = torch.optim.AdamW(chain(base_model.parameters(), learning_strategy.parameters(), alt_count_predictor.parameters()), + lr=training_params.learning_rate, weight_decay=training_params.weight_decay) + # train scheduler needs to be given the thing that's supposed to decrease at the end of each epoch + train_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + train_optimizer, factor=0.2, patience=5, threshold=0.001, min_lr=(training_params.learning_rate/100), verbose=True) + + classifier_on_top = MLP([base_model.output_dimension()] + [30, -1, -1, -1, 10] + [1])\ + .to(device=base_model._device, dtype=base_model._dtype) + classifier_bce = torch.nn.BCEWithLogitsLoss(reduction='none') + + classifier_optimizer = torch.optim.AdamW(classifier_on_top.parameters(), + lr=training_params.learning_rate, + weight_decay=training_params.weight_decay, + fused=is_cuda) + classifier_metrics = LossMetrics() + + validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold + train_loader = dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), training_params.batch_size, is_cuda, training_params.num_workers) + valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.batch_size, is_cuda, training_params.num_workers) + + for epoch in trange(1, training_params.num_epochs + 1, desc="Epoch"): + p = epoch - 1 + new_alpha = (2/(1 + math.exp(-0.1*p))) - 1 + alt_count_gradient_reversal.set_alpha(new_alpha) # alpha increases linearly + start_epoch = time.time() + print(f"Start of epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}") + for epoch_type in (utils.Epoch.TRAIN, utils.Epoch.VALID): + base_model.set_epoch_type(epoch_type) + loss_metrics = LossMetrics() + + loader = train_loader if epoch_type == utils.Epoch.TRAIN else valid_loader + loader_iter = iter(loader) + + next_batch_cpu = next(loader_iter) + next_batch = next_batch_cpu.copy_to(base_model._device, non_blocking=is_cuda) + + pbar = tqdm(range(len(loader)), mininterval=60) + for n in pbar: + batch_cpu = next_batch_cpu + batch = next_batch + + # Optimization: Asynchronously send the next batch to the device while the model does work + next_batch_cpu = next(loader_iter) + next_batch = next_batch_cpu.copy_to(base_model._device, non_blocking=is_cuda) + + # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss + weights = calculate_batch_weights(batch_cpu, dataset, by_count=True) + weights = weights.to(device=base_model._device, dtype=base_model._dtype, non_blocking=True) + + # unused output is the embedding of ref and alt alleles with context + representations, _ = base_model.calculate_representations(batch, weight_range=base_model._params.reweighting_range) + losses = learning_strategy.loss_function(base_model, batch, representations) + + if losses is None: + continue + + loss_metrics.record_losses(losses.detach(), batch, weights) + + # gradient reversal means parameters before the representation try to maximize alt count prediction loss, i.e. features + # try to forget alt count, while parameters after the representation try to minimize it, i.e. they try + # to achieve the adversarial task + alt_count_pred = torch.sigmoid(alt_count_predictor.forward(alt_count_gradient_reversal(representations)).squeeze()) + alt_count_target = batch.get_alt_counts().to(dtype=alt_count_pred.dtype)/20 + alt_count_losses = alt_count_loss_func(alt_count_pred, alt_count_target) + + alt_count_adversarial_metrics.record_losses(alt_count_losses.detach(), batch, weights=torch.ones_like(alt_count_losses)) + + loss = torch.sum((weights * losses) + alt_count_losses) + + classification_logits = classifier_on_top.forward(representations.detach()).reshape(batch.size()) + classification_losses = classifier_bce(classification_logits, batch.get_training_labels()) + classification_loss = torch.sum(batch.get_is_labeled_mask() * weights * classification_losses) + classifier_metrics.record_losses(classification_losses.detach(), batch, batch.get_is_labeled_mask() * weights) + + if epoch_type == utils.Epoch.TRAIN: + utils.backpropagate(train_optimizer, loss) + utils.backpropagate(classifier_optimizer, classification_loss) + + # done with one epoch type -- training or validation -- for this epoch + loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer) + classifier_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="auxiliary-classifier-") + alt_count_adversarial_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="alt-count-adversarial-predictor") + + if epoch_type == utils.Epoch.TRAIN: + train_scheduler.step(loss_metrics.get_labeled_loss()) + + print(f"Labeled base model loss for {epoch_type.name} epoch {epoch}: {loss_metrics.get_labeled_loss():.3f}") + print(f"Labeled auxiliary classifier loss for {epoch_type.name} epoch {epoch}: {classifier_metrics.get_labeled_loss():.3f}") + print(f"Alt count adversarial loss for {epoch_type.name} epoch {epoch}: {alt_count_adversarial_metrics.get_labeled_loss():.3f}") + print(f"End of epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}, time elapsed(s): {time.time() - start_epoch:.2f}") + # done with training and validation for this epoch + # note that we have not learned the AF spectrum yet + # done with training + + record_embeddings(base_model, train_loader, summary_writer) + + +# after training for visualizing clustering etc of base model embeddings +def record_embeddings(base_model: BaseModel, loader, summary_writer: SummaryWriter): + # base_model.freeze_all() whoops -- it doesn't have freeze_all + embedding_metrics = EmbeddingMetrics() + ref_alt_seq_metrics = EmbeddingMetrics() + + pbar = tqdm(enumerate(loader), mininterval=60) + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(base_model._device, non_blocking=base_model._device.type=='cuda') + representations, ref_alt_seq_embeddings = base_model.calculate_representations(batch, weight_range=base_model._params.reweighting_range) + + representations = representations.cpu() + ref_alt_seq_embeddings = ref_alt_seq_embeddings.cpu() + + labels = [("artifact" if label > 0.5 else "non-artifact") if is_labeled > 0.5 else "unlabeled" for (label, is_labeled) in + zip(batch.get_training_labels().tolist(), batch.get_is_labeled_mask().tolist())] + for (metrics, embeddings) in [(embedding_metrics, representations), (ref_alt_seq_metrics, ref_alt_seq_embeddings)]: + metrics.label_metadata.extend(labels) + metrics.correct_metadata.extend(["unknown"] * batch.size()) + metrics.type_metadata.extend([Variation(idx).name for idx in batch.get_variant_types().tolist()]) + alt_count_strings = [str(round_up_to_nearest_three(min(MAX_COUNT, ac))) for ac in batch.get_alt_counts().tolist()] + metrics.truncated_count_metadata.extend(alt_count_strings) + metrics.representations.append(embeddings) + embedding_metrics.output_to_summary_writer(summary_writer) + ref_alt_seq_metrics.output_to_summary_writer(summary_writer, prefix="ref and alt allele context") + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/dna_sequence_convolution.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/dna_sequence_convolution.py new file mode 100644 index 00000000000..cec4f110fda --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/dna_sequence_convolution.py @@ -0,0 +1,96 @@ +from torch import nn +from math import floor + + +def conv_output_length(input_length, kernel_size=1, stride=1, pad=0, dilation=1, **kwargs): + """ + Output length of 1D convolution given input length and various options. Copied from PyTorch docs + """ + return floor(((input_length + (2 * pad) - (dilation * (kernel_size - 1)) - 1) / stride) + 1) + + +def pool_output_length(input_length, kernel_size=1, stride=None, pad=0, dilation=1, **kwargs): + """ + Output length of 1D pooling given input length and various options. Copied from PyTorch docs. + Differs from convolution in that stride equals kernel_size by default. + """ + return floor(((input_length + (2 * pad) - (dilation * (kernel_size - 1)) - 1) / (kernel_size if stride is None else stride)) + 1) + + +INITIAL_NUM_CHANNELS = 10 # 2 x (4 + 1); the '2' is for ref and alt; the 4 is A/C/G/T; the +1 is for insertion/deletion +TOKEN_SEPARATOR = '/' +KEY_VALUE_SEPARATOR = '=' + + +class DNASequenceConvolution(nn.Module): + """ + A fully-connected network (multi-layer perceptron) that we need frequently + as a sub-network. It is parameterized by the dimensions of its layers, starting with + the input layer and ending with the output. Output is logits and as such no non-linearity + is applied after the last linear transformation. + + Layer strings have the format, for example: + ['convolution/kernel_size=3/out_channels=64', + 'pool/kernel_size=2', + 'leaky_relu', + 'convolution/kernel_size=3/dilation=2/out_channels=5', + 'selu', + 'flatten', + 'linear/out_features=10'] + """ + + def __init__(self, layer_strings, sequence_length): + super(DNASequenceConvolution, self).__init__() + + # note: convention for Pytorch convolutional tensors is (batch, channel, sequence) + last_layer_channels, last_layer_length = (INITIAL_NUM_CHANNELS, sequence_length) # we exclude the batch dimension, which is the first + + layers = [] + for layer_string in layer_strings: + tokens = layer_string.split(TOKEN_SEPARATOR) + layer_type_token = tokens[0] + + kwargs = {} + for key_value_token in tokens[1:]: + key, value = tuple(key_value_token.split(KEY_VALUE_SEPARATOR)) + kwargs[key] = int(value) # we're assuming all params are integers + + if layer_type_token == "convolution": + kwargs["in_channels"] = last_layer_channels + layers.append(nn.Conv1d(**kwargs)) + last_layer_channels, last_layer_length = kwargs["out_channels"], conv_output_length(last_layer_length, **kwargs) + elif layer_type_token == "pool": + assert last_layer_length > 1, "You are trying to pool a length-1 sequence, which, while defined, is silly" + layers.append(nn.MaxPool1d(**kwargs)) + last_layer_length = pool_output_length(last_layer_length, **kwargs) + elif layer_type_token == "leaky_relu": + layers.append(nn.LeakyReLU()) + elif layer_type_token == "selu": + layers.append(nn.SELU()) + elif layer_type_token == "batch_norm": + layers.append(nn.BatchNorm1d(last_layer_channels)) + elif layer_type_token == "flatten": + layers.append(nn.Flatten()) # by default, batch dimension is not flattened + last_layer_channels, last_layer_length = last_layer_channels * last_layer_length, 1 # no position left, everything is a "channel" + elif layer_type_token == "linear": + assert last_layer_length == 1, "Trying to use fully-connected layer before data have been flattened" + kwargs["in_features"] = last_layer_channels + layers.append(nn.Linear(**kwargs)) + last_layer_channels = kwargs["out_features"] + else: + raise Exception("unsupported layer_type: " + layer_type_token) + + assert last_layer_length == 1, "data have not been flattened" + self._output_dimension = last_layer_channels + self._model = nn.Sequential(*layers) + + def output_dimension(self): + return self._output_dimension + + def forward(self, x): + """ + :param x: a batch of DNA sequences represented as a 3D tensor -- 1st index batch, 2nd index channel (A, C, G, T), + 3rd index position in the sequence. + :return: + """ + return self._model.forward(x) \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gated_mlp.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gated_mlp.py new file mode 100644 index 00000000000..22719df743d --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gated_mlp.py @@ -0,0 +1,251 @@ +""" +--- +title: Pay Attention to MLPs (gMLP) +summary: > + This is an annotated implementation/tutorial of Pay Attention to MLPs (gMLP) in PyTorch. +--- + +# Pay Attention to MLPs (gMLP) + +This is a [PyTorch](https://pytorch.org) implementation of the paper +[Pay Attention to MLPs](https://arxiv.org/abs/2105.08050). + +This paper introduces a Multilayer Perceptron (MLP) based architecture with gating, +which they name **gMLP**. It consists of a stack of $L$ *gMLP* blocks. + +Here is [the training code](experiment.html) for a gMLP model based autoregressive model. +""" +# copied from https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/gmlp/__init__.py +# then modified for the symmetric case + +from typing import Optional + +import torch +from torch import nn + +from permutect import utils + + +class GatedMLPBlock(nn.Module): + """ + ## gMLP Block + + Each block does the following transformations to input embeddings + $X \in \mathbb{R}^{n \times d}$ where $n$ is the sequence length + and $d$ is the dimensionality of the embeddings: + + \begin{align} + Z &= \sigma(XU) \\ + \tilde{Z} &= s(Z) \\ + Y &= \tilde{Z}V \\ + \end{align} + + where $V$ and $U$ are learnable projection weights. + $s(\cdot)$ is the Spacial Gating Unit defined below. + Output dimensionality of $s(\cdot)$ will be half of $Z$. + $\sigma$ is an activation function such as + [GeLU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html). + """ + + def __init__(self, d_model: int, d_ffn: int): + """ + * `d_model` is the dimensionality ($d$) of $X$ i.e. the embedding dimension of each read + * `d_ffn` is the dimensionality of $Z$, that is, the hidden dimension of each block + """ + super(GatedMLPBlock, self).__init__() + # Normalization layer fro Pre-Norm + self.norm = nn.LayerNorm([d_model]) + # Activation function $\sigma$ + self.activation = nn.SELU() + # Projection layer for $Z = \sigma(XU)$ + self.proj1 = nn.Linear(d_model, d_ffn) + # Spacial Gating Unit $s(\cdot)$ + self.sgu = SpacialGatingUnit(d_ffn) + # Projection layer for $Y = \tilde{Z}V$ + self.proj2 = nn.Linear(d_ffn // 2, d_model) + # Embedding size (required by [Encoder](../models.html#Encoder). + # We use the encoder module from transformer architecture and plug + # *gMLP* block as a replacement for the [Transformer Layer](../models.html#Encoder). + self.size = d_model + + # X is 2D, counts are the numbers of elements in each consecutive group of rows that form a self-attention group + # that is, is X has 10 rows and counts = [2,3,5], elements 0-1, 2-4, and 5-9 form independent self-attention groups + # In other words, all the reads of a batch are flattened together in X -- the batch information is in counts + def forward(self, x_re: torch.Tensor, counts: torch.IntTensor): + """ + * `x_bre` is the input read embedding tensor of shape Batch x Reads x Embedding + """ + # Norm, projection to d_ffn, and activation $Z = \sigma(XU)$ + z_rd = self.activation(self.proj1(self.norm(x_re))) + # Spacial Gating Unit $\tilde{Z} = s(Z)$ + gated_rd = self.sgu.forward(z_rd, counts) + # Final projection $Y = \tilde{Z}V$ back to embedding dimension + gated_re = self.proj2(gated_rd) + + # Add the shortcut connection + return x_re + gated_re + + +class SpacialGatingUnit(nn.Module): + """ + ## Spatial Gating Unit + ORIGINAL: + $$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$ + + where $f_{W,b}(Z) = W Z + b$ is a linear transformation along the sequence dimension, + and $\odot$ is element-wise multiplication. + $Z$ is split into to parts of equal size $Z_1$ and $Z_2$ along the channel dimension (embedding dimension). + + MODIFIED: f_{W,b} must be permutation-invariant, and the only way to achieve this is if W has a constant diagonal element + and a constant off-diagonal element. That is: WZ = a*(mean of Z along sequence dimension) + b Z + bias + + Due to taking the mean the model no longer needs a constant sequence length. + """ + def __init__(self, d_z: int): + """ + * `d_z` is the dimensionality of $Z$, which is d_ffn of the SGU block + * `seq_len` is the sequence length + """ + super(SpacialGatingUnit, self).__init__() + # Normalization layer before applying $f_{W,b}(\cdot)$ + self.norm = nn.LayerNorm([d_z // 2]) + # Weight $W$ in $f_{W,b}(\cdot)$. + + # TODO: shouldn't alpha and beta be element-by-element??? + self.alpha = nn.Parameter(torch.tensor(0.01)) + self.beta = nn.Parameter(torch.tensor(0.01)) + + # Z is 2D, counts are the numbers of elements in each consecutive group of rows that form a self-attention group + # that is, is X has 10 rows and counts = [2,3,5], elements 0-1, 2-4, and 5-9 form independent self-attention groups + def forward(self, z_rd: torch.Tensor, counts: torch.IntTensor): + # Split $Z$ into $Z_1$ and $Z_2$ over the hidden dimension and normalize $Z_2$ before $f_{W,b}(\cdot)$ + z1_rd, z2_rd = torch.chunk(z_rd, 2, dim=-1) + z2_rd = self.norm(z2_rd) + + # TODO: self.beta needs to multiply the mean field here!!! + z2_rd = 1 + self.alpha * z2_rd + utils.means_over_rows(z2_rd, counts, keepdim=True) + + # $Z_1 \odot f_{W,b}(Z_2)$ + return z1_rd * z2_rd + + +class GatedMLP(nn.Module): + def __init__(self, d_model: int, d_ffn: int, num_blocks: int): + super(GatedMLP, self).__init__() + + self.blocks = nn.ModuleList([GatedMLPBlock(d_model, d_ffn) for _ in range(num_blocks)]) + + # X is 2D, counts are the numbers of elements in each consecutive group of rows that form a self-attention group + # that is, is X has 10 rows and counts = [2,3,5], elements 0-1, 2-4, and 5-9 form independent self-attention groups + def forward(self, x, counts): + for block in self.blocks: + x = block.forward(x, counts) + return x + + +class GatedRefAltMLPBlock(nn.Module): + """ + Like the above, but ref reads see the mean field of ref reads and alt reads see the mean fields of both ref and alt + Note that using mean fields implies the model never *counts* ref or alt reads, nor knows their relative frequencies + """ + + def __init__(self, d_model: int, d_ffn: int): + """ + * `d_model` is the dimensionality of read embeddings + * `d_ffn` is the hidden dimension of each block + """ + super(GatedRefAltMLPBlock, self).__init__() + # Normalization layer fro Pre-Norm + self.norm = nn.LayerNorm([d_model]) + # Activation function $\sigma$ + self.activation = nn.SELU() + # Projection layer for $Z = \sigma(XU)$ + self.proj1_ref = nn.Linear(d_model, d_ffn) + self.proj1_alt = nn.Linear(d_model, d_ffn) + # Spacial Gating Unit $s(\cdot)$ + self.sgu = SpacialGatingUnitRefAlt(d_ffn) + # Projection layer for $Y = \tilde{Z}V$ + self.proj2_ref = nn.Linear(d_ffn // 2, d_model) + self.proj2_alt = nn.Linear(d_ffn // 2, d_model) + # Embedding size (required by [Encoder](../models.html#Encoder). + # We use the encoder module from transformer architecture and plug + # *gMLP* block as a replacement for the [Transformer Layer](../models.html#Encoder). + self.size = d_model + + def forward(self, ref_re: torch.Tensor, alt_re: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor): + """ + * `x_bre` is the input read embedding tensor of shape Batch x Reads x Embedding + """ + # Norm, projection to d_ffn, and activation $Z = \sigma(XU)$ + zref_rd = self.activation(self.proj1_ref(self.norm(ref_re))) + zalt_rd = self.activation(self.proj1_alt(self.norm(alt_re))) + + # Spacial Gating Unit $\tilde{Z} = s(Z)$ + gated_ref_rd, gated_alt_rd = self.sgu.forward(zref_rd, zalt_rd, ref_counts, alt_counts) + # Final projection $Y = \tilde{Z}V$ back to embedding dimension + gated_ref_re = self.proj2_ref(gated_ref_rd) + gated_alt_re = self.proj2_alt(gated_alt_rd) + + # Add the shortcut connection + return ref_re + gated_ref_re, alt_re + gated_alt_re + + +class SpacialGatingUnitRefAlt(nn.Module): + """ + """ + def __init__(self, d_z: int): + """ + * `d_z` is the dimensionality of $Z$, which is d_ffn of the SGU block + * `seq_len` is the sequence length + """ + super(SpacialGatingUnitRefAlt, self).__init__() + # Normalization layer before applying $f_{W,b}(\cdot)$ + self.norm = nn.LayerNorm([d_z // 2]) + # Weight $W$ in $f_{W,b}(\cdot)$. + + # TODO: maybe let these parameters be element-by-element vectors? + self.alpha_ref = nn.Parameter(torch.tensor(0.01)) + self.alpha_alt = nn.Parameter(torch.tensor(0.01)) + self.beta_ref = nn.Parameter(torch.tensor(0.01)) + self.beta_alt = nn.Parameter(torch.tensor(0.01)) + self.gamma = nn.Parameter(torch.tensor(0.01)) + + # regularizer / sort of imputed value for when there are no ref counts + self.ref_regularizer = nn.Parameter(0.1 * torch.ones(d_z // 2)) + self.regularizer_weight_pre_exp = nn.Parameter(torch.log(torch.tensor(0.1))) + + def forward(self, zref_rd: torch.Tensor, zalt_rd: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor): + + # Split $Z$ into $Z_1$ and $Z_2$ over the hidden dimension and normalize $Z_2$ before $f_{W,b}(\cdot)$ + z1_ref_rd, z2_ref_rd = torch.chunk(zref_rd, 2, dim=-1) + z1_alt_rd, z2_alt_rd = torch.chunk(zalt_rd, 2, dim=-1) + z2_ref_rd = self.norm(z2_ref_rd) + z2_alt_rd = self.norm(z2_alt_rd) + + # these are means by variant -- need repeat_interleave to make them by-read + ref_mean_field_vd = utils.means_over_rows_with_regularizer(z2_ref_rd, ref_counts, self.ref_regularizer, torch.exp(self.regularizer_weight_pre_exp) + 0.25) + alt_mean_field_vd = utils.means_over_rows(z2_alt_rd, alt_counts) + + ref_mean_field_on_ref_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=ref_counts) + ref_mean_field_on_alt_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=alt_counts) + alt_mean_field_on_alt_rd = torch.repeat_interleave(alt_mean_field_vd, dim=0, repeats=alt_counts) + + # same as above except now there is an additional term for the ref mean field influence on alt + # maybe later also let alt mean field influence ref + z2_ref_rd = 1 + self.alpha_ref * z2_ref_rd + self.beta_ref * ref_mean_field_on_ref_rd + z2_alt_rd = 1 + self.alpha_alt * z2_alt_rd + self.beta_alt * alt_mean_field_on_alt_rd + self.gamma * ref_mean_field_on_alt_rd + + # $Z_1 \odot f_{W,b}(Z_2)$ + return z1_ref_rd * z2_ref_rd, z1_alt_rd * z2_alt_rd + + +class GatedRefAltMLP(nn.Module): + def __init__(self, d_model: int, d_ffn: int, num_blocks: int): + super(GatedRefAltMLP, self).__init__() + + self.blocks = nn.ModuleList([GatedRefAltMLPBlock(d_model, d_ffn) for _ in range(num_blocks)]) + + def forward(self, ref, alt, ref_counts, alt_counts): + for block in self.blocks: + ref, alt = block(ref, alt, ref_counts, alt_counts) + return ref, alt diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/functional.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/functional.py new file mode 100644 index 00000000000..195a3d767c3 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/functional.py @@ -0,0 +1,25 @@ +from typing import Any + +from torch.autograd import Function + + +class GradientReversal(Function): + @staticmethod + def jvp(ctx: Any, *grad_inputs: Any) -> Any: + pass + + @staticmethod + def forward(ctx, x, alpha): + ctx.save_for_backward(x, alpha) + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + _, alpha = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_input = - alpha * grad_output + return grad_input, None + + +revgrad = GradientReversal.apply diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/module.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/module.py new file mode 100644 index 00000000000..070c7cf3c08 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/gradient_reversal/module.py @@ -0,0 +1,15 @@ +from .functional import revgrad +import torch +from torch import nn + + +class GradientReversal(nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = torch.tensor(alpha, requires_grad=False) + + def forward(self, x): + return revgrad(x, self.alpha) + + def set_alpha(self, alpha_new): + self.alpha = torch.tensor(alpha_new, requires_grad=False) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/mlp.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/mlp.py new file mode 100644 index 00000000000..957084bcf13 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/mlp.py @@ -0,0 +1,67 @@ +import torch +from torch import nn, Tensor +from typing import List + + +class DenseSkipBlock(nn.Module): + """ + computes x + f(x) where f(x) has some given number of linear layers, each with input and output dimension equal + to that of the input x. As suggested in arxiv:1603.05027, Identity Maps in Deep Residual Networks, nonlinearities come before each linear transformation + """ + def __init__(self, input_size: int, num_layers: int, batch_normalize: bool = False, dropout_p: float = 0): + super(DenseSkipBlock, self).__init__() + self.mlp = MLP((num_layers + 1) * [input_size], batch_normalize, dropout_p, prepend_activation=True) + + # scale the MLP and initially set it to a small amount so that the block is close to an identity map early in learning + self.alpha = nn.Parameter(torch.tensor(0.1)) + + def forward(self, x): + return x + self.alpha * self.mlp.forward(x) + + +class MLP(nn.Module): + """ + A fully-connected network (multi-layer perceptron) that we need frequently + as a sub-network. It is parameterized by the dimensions of its layers, starting with + the input layer and ending with the output. Output is logits and as such no non-linearity + is applied after the last linear transformation. + """ + + def __init__(self, layer_sizes: List[int], batch_normalize: bool = False, dropout_p: float = 0, prepend_activation: bool = False): + super(MLP, self).__init__() + + layers = [nn.SELU()] if prepend_activation else [] + self._input_dim = layer_sizes[0] + input_dim = layer_sizes[0] + for k, output_dim in enumerate(layer_sizes[1:]): + # negative output dimension -d will denote a d-layer residual skip connection + # the output dimension of which equals the current input dimension + if output_dim < 0: + layers.append(DenseSkipBlock(input_dim, -output_dim, batch_normalize, dropout_p)) + continue + + if batch_normalize: + layers.append(nn.BatchNorm1d(num_features=input_dim)) + + layers.append(nn.Linear(input_dim, output_dim)) + + if dropout_p > 0: + layers.append(nn.Dropout(p=dropout_p)) + + # k runs from 0 to len(layer_sizes) - 2. Omit the nonlinearity after the last layer. + if k < len(layer_sizes) - 2: + layers.append(nn.SELU()) + + input_dim = output_dim # note that this does not happen for a residual skip connection + + self._output_dim = input_dim + self._model = nn.Sequential(*layers) + + def input_dimension(self) -> int: + return self._input_dim + + def output_dimension(self) -> int: + return self._output_dim + + def forward(self, x: Tensor) -> Tensor: + return self._model.forward(x) \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/monotonic.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/monotonic.py new file mode 100644 index 00000000000..e2c4afeea5b --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/monotonic.py @@ -0,0 +1,129 @@ +from torch import nn +import torch.nn.functional as F +import torch +import math +from typing import List + + +class MonoDenseLayer(nn.Module): + """ + MonoDenseLayer from Constrained Monotonic Neural Networks, Runje and Shankaranarayana, https://arxiv.org/abs/2205.11775 + + It is a modification of a plain old linear layer. + + 1) The output is constrained to be monotonically increasing, decreasing, or unconstrained with respect to each input + + 2) Input vectors are assumed ordered with increasing features, then decreasing, then unconstrained + """ + + def __init__(self, input_dimension: int, output_dimension: int, num_increasing: int, num_decreasing: int, omit_activation: bool = False): + super(MonoDenseLayer, self).__init__() + + self.convex_activation = torch.relu + + self.omit_activation = omit_activation + + self.num_constrained = num_increasing + num_decreasing + num_free = input_dimension - self.num_constrained + assert self.num_constrained <= input_dimension + assert self.num_constrained > 0 + + self.input_dimension = input_dimension + self.output_dimension = output_dimension + + # mask has -1's for decreasing features, otherwise 1's + # in the forward pass we multiply by the mask for convenience so that monotonically increasing AND decreasing can both + # be treated as increasing + self.mask = nn.Parameter(torch.ones(input_dimension), requires_grad=False) + self.mask[num_increasing: num_increasing + num_decreasing] = -1 + + self.monotonic_W = nn.Parameter(torch.empty((output_dimension, self.num_constrained))) + nn.init.kaiming_uniform_(self.monotonic_W, a=math.sqrt(5)) + + self.free_W = nn.Parameter(torch.empty((output_dimension, input_dimension - self.num_constrained))) if num_free > 0 else None + if self.free_W is not None: + nn.init.kaiming_uniform_(self.free_W, a=math.sqrt(5)) + + self.b = nn.Parameter(torch.empty(output_dimension)) + bound = 1 / math.sqrt(input_dimension) + nn.init.uniform_(self.b, -bound, bound) + + def forward(self, x): + flipped = x * self.mask + + # note that monotonicity is enforced by taking the absolute value of the monotonic weight matrix + monotonic_contribution = F.linear(flipped[:, :self.num_constrained], torch.abs(self.monotonic_W)) + free_contribution = F.linear(flipped[:, self.num_constrained:], self.free_W) if self.free_W is not None else 0 + + before_activation = monotonic_contribution + free_contribution + self.b + + if self.omit_activation: + return before_activation + + # as in the paper, we apply three nonlinear activation functions: 1) an ordinary convex activation g(x) which + # could be a ReLU, leaky ReLU, tanh etc; 2) the concave reflection -g(-x); 3) g(x+1)-g(1) (if x < 0) or g(1) - g(1-x) (if x > 0) + + features_per_activation = self.output_dimension // 3 + + left = before_activation[:, :features_per_activation] + middle = before_activation[:, features_per_activation:(2*features_per_activation)] + right = before_activation[:, (2*features_per_activation):] + + output1 = self.convex_activation(left) + output2 = -self.convex_activation(-middle) + output3 = torch.sgn(right)*(self.convex_activation(torch.ones_like(right)) - self.convex_activation(1-torch.abs(right))) + + return torch.hstack([output1, output2, output3]) + + +class MonotonicHighwayLayer(nn.Module): + """ + This is a purely monotonic increasing layer, like all layers but the first in the MonoDense architecture below. + + It is a highway network layer of the form: + output = (1 - gate) * input + gate * nonlinear(input) + where the gate is the sigmoid of some linear transformation of the input, gating is element-by-element, and the + nonlinear function of the input is one or more monotonic dense layers as above. + """ + def __init__(self, dim: int, num_layers: int): + super(MonotonicHighwayLayer, self).__init__() + self.nonlinear = MonoDense(input_dimension=dim, output_dimensions=(num_layers * [dim]), num_increasing=dim, num_decreasing=0) + + # initialize with negative bias so behavior starts near identity with gates almost closed + self.gate_pre_sigmoid = nn.Parameter(torch.tensor(-2.0)) + + def forward(self, x): + gate = torch.sigmoid(self.gate_pre_sigmoid) + + return (1 - gate) * x + gate * self.nonlinear(x) + + +class MonoDense(nn.Module): + """ + + """ + + def __init__(self, input_dimension: int, output_dimensions: List[int], num_increasing: int, num_decreasing): + super(MonoDense, self).__init__() + + self.input_dimension = input_dimension + self.layers = torch.nn.Sequential() + + last_layer_dim = input_dimension + for layer, output_dim in enumerate(output_dimensions): + omit_activation = (layer == len(output_dimensions) - 1) + + if output_dim > 0: + # layers after the first are purely monotonic increasing + n_increasing = num_increasing if layer == 0 else last_layer_dim + n_decreasing = num_decreasing if layer == 0 else 0 + self.layers.append(MonoDenseLayer(last_layer_dim, output_dim, n_increasing, n_decreasing, omit_activation=omit_activation)) + last_layer_dim = output_dim + else: # negative output dimension denotes monotonic highway layer + if layer == 0: + assert num_increasing == input_dimension, "initial highway layer is only valid for purely increasing network" + num_hidden_layers = -output_dim + self.layers.append(MonotonicHighwayLayer(dim=last_layer_dim, num_layers=num_hidden_layers)) + + def forward(self, x): + return self.layers.forward(x) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/normal_artifact_spectrum.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/normal_artifact_spectrum.py new file mode 100644 index 00000000000..4f7e0447068 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/normal_artifact_spectrum.py @@ -0,0 +1,59 @@ +import math + +import torch +from torch import nn + +import matplotlib.pyplot as plt +import numpy as np + +EPSILON = 0.001 + + +class NormalArtifactSpectrum(nn.Module): + def __init__(self, num_samples: int): + super(NormalArtifactSpectrum, self).__init__() + + self.num_samples = num_samples + + self.W = nn.Linear(in_features=2, out_features=2) + + # this initializes to be sort of uniform on [0,1]x[0,1], with some bias toward lower allele fractions + # if we don't initialize carefully all the weight is near (0.5,0.5) and the model gives basically zero + # likelihood to low allele fractions + with torch.no_grad(): + self.W.weight.copy_(torch.Tensor([[1.7, 0], [0, 1.7]])) + self.W.bias.copy_(torch.Tensor([-0.1, -0.1])) + + def forward(self, tumor_alt_1d: torch.Tensor, tumor_ref_1d: torch.Tensor, normal_alt_1d: torch.Tensor, normal_ref_1d: torch.Tensor): + if torch.sum(normal_alt_1d) < 1: # shortcut if no normal alts in the whole batch + return -9999 * torch.ones_like(tumor_alt_1d) + batch_size = len(tumor_alt_1d) + tumor_fractions_2d, normal_fractions_2d = self.get_tumor_and_normal_fraction(batch_size, self.num_samples) + + log_likelihoods_2d = torch.reshape(tumor_alt_1d, (batch_size, 1)) * torch.log(tumor_fractions_2d) \ + + torch.reshape(tumor_ref_1d, (batch_size, 1)) * torch.log(1 - tumor_fractions_2d) \ + + torch.reshape(normal_alt_1d, (batch_size, 1)) * torch.log(normal_fractions_2d) \ + + torch.reshape(normal_ref_1d, (batch_size, 1)) * torch.log(1 - normal_fractions_2d) + + # average over sample dimension + log_likelihoods_1d = torch.logsumexp(log_likelihoods_2d, dim=1) - math.log(self.num_samples) + + # zero likelihood if no alt in normal + no_alt_in_normal_mask = normal_alt_1d < 1 + return -9999 * no_alt_in_normal_mask + log_likelihoods_1d * torch.logical_not(no_alt_in_normal_mask) + + def get_tumor_and_normal_fraction(self, batch_size, num_samples): + gaussian_3d = torch.randn(batch_size, num_samples, 2) + correlated_gaussian_3d = self.W.forward(gaussian_3d) + # to prevent nans, map onto [EPSILON, 1 - EPSILON] + tumor_fractions_2d = EPSILON + (1 - 2 * EPSILON) * torch.sigmoid(correlated_gaussian_3d[:, :, 0]) + normal_fractions_2d = EPSILON + (1 - 2 * EPSILON) * torch.sigmoid(correlated_gaussian_3d[:, :, 1]) + return tumor_fractions_2d, normal_fractions_2d + + # TODO: move this method to plotting + def density_plot_on_axis(self, ax): + tumor_fractions_2d, normal_fractions_2d = self.get_tumor_and_normal_fraction(batch_size=1, num_samples=100000) + tumor_f = torch.squeeze(tumor_fractions_2d).detach().numpy() + normal_f = torch.squeeze(normal_fractions_2d).detach().numpy() + + ax.hist2d(tumor_f, normal_f, bins=(100, 100), range=[[0, 1], [0, 1]], density=True, cmap=plt.cm.jet) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/normal_seq_error_spectrum.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/normal_seq_error_spectrum.py new file mode 100644 index 00000000000..40fd1c4125c --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/normal_seq_error_spectrum.py @@ -0,0 +1,62 @@ +import math + +import torch +from torch import nn + +import matplotlib.pyplot as plt +import numpy as np + +EPSILON = 0.0001 + +# the mean of a half-normal distribution is related to the standard deviation sigma of its corresponding normal distribution by +# sigma = mean * sqrt(pi/2) +SQRT_PI_OVER_2 = math.sqrt(math.pi / 2) + + +# we can't use a beta binomial for normal seq error because betas have such long tails that even if we constrain the mean +# to be small there is too large a probability of a large allele fraction. Here we assume an underlying half normal distribution on +# the allele fraction ie it is a half normal-binomial. Since these are not conjugate we have to explicitly sample and +# essentially perform a brute force Monte Carlo integral. +class NormalSeqErrorSpectrum(nn.Module): + def __init__(self, num_samples: int, max_mean: float): + super(NormalSeqErrorSpectrum, self).__init__() + + self.num_samples = num_samples + + self.max_mean = max_mean + + # this is 1/lambda parameter + # TODO: magic constant initialization!!! + self.mean_pre_sigmoid = torch.nn.Parameter(torch.tensor(0.0)) + + def forward(self, alt_counts_1d: torch.Tensor, ref_counts_1d: torch.Tensor): + batch_size = len(alt_counts_1d) + fractions_2d = self.get_fractions(batch_size, self.num_samples) + + log_likelihoods_2d = torch.reshape(alt_counts_1d, (batch_size, 1)) * torch.log(fractions_2d) \ + + torch.reshape(ref_counts_1d, (batch_size, 1)) * torch.log(1 - fractions_2d) + + # average over sample dimension + log_likelihoods_1d = torch.logsumexp(log_likelihoods_2d, dim=1) - math.log(self.num_samples) + + combinatorial_term = torch.lgamma(alt_counts_1d + ref_counts_1d + 1) - torch.lgamma(alt_counts_1d + 1) - torch.lgamma(ref_counts_1d + 1) + + return combinatorial_term + log_likelihoods_1d + + def get_mean(self): + return torch.sigmoid(self.mean_pre_sigmoid) * self.max_mean + + def get_fractions(self, batch_size, num_samples): + actual_mean = torch.sigmoid(self.mean_pre_sigmoid) * self.max_mean + actual_sigma = SQRT_PI_OVER_2 * actual_mean + normal_samples = torch.randn(batch_size, num_samples, device=actual_sigma.device) + half_normal_samples = torch.abs(normal_samples) + fractions_2d_unbounded = actual_sigma * half_normal_samples + # apply tanh to constrain fractions to [0, 1), and then to [EPSILON, 1 - EPSILON] for numerical stability + fractions_2d = EPSILON + (1 - 2*EPSILON)*torch.tanh(fractions_2d_unbounded) + return fractions_2d + + # TODO: move this method to plotting + def density_plot_on_axis(self, ax): + fractions = torch.squeeze(self.get_fractions(1, 100000)).detach().numpy() + ax.hist(fractions, bins=1000, range=[0, 1]) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/overdispersed_binomial_mixture.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/overdispersed_binomial_mixture.py new file mode 100644 index 00000000000..de7bcaaa9ca --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/overdispersed_binomial_mixture.py @@ -0,0 +1,203 @@ +import math +from typing import List + +import torch +from permutect import utils +from torch import nn, exp, unsqueeze, logsumexp +from torch.nn.functional import softmax, log_softmax + +from permutect.architecture.mlp import MLP +from permutect.metrics.plotting import simple_plot +from permutect.utils import beta_binomial, gamma_binomial, binomial, Variation + + +class OverdispersedBinomialMixture(nn.Module): + """ + This model takes in 1D tensor inputs (variant type indices by batch) and as a function of input has a Beta OR Gamma mixture model. That is, it computes for each input + vector 1) a vector of mixture component weights 2) a vector of the alpha shape parameters of each component and 3) a + vector of beta shape parameters of each component. Due to batching these are all represented as 2D tensors. + + Note that both the Beta and Gamma distributions' shape parameters are traditionally called "alpha" and "beta". + + The computed likelihoods take in a 1D batch of total counts n and 1D batch of "success" counts k. + + It optionally has a max mean that scales every mean to some amount less than or equal to 1, which is useful when we want + to force the mixture to represent only small fractions. + + When using a BetaBinomial mixture, due to conjugacy the integral over the latent probability of success (in our uses, + this is the allele fraction of variants or artifacts) is exact and we use a closed form analytic expression for the + density of a BetaBinomial. That is, the probability (k = alt count, n = depth, f = latent allele fraction) + + P(k|n, alpha, beta) = integral{Beta(f|alpha, beta) * Binomial(k|n, f)} + + is exact. + + When using a GammaBinomial mixture, i.e. one with a Gamma prior Gamma(f, alpha, beta) the cannot do the integral exactly. + However, the *binomial* factor Binom(k|n,f), which as a function of f is a Beta distribution, is extremely well-approximated + by a Gamma distribution, and the product of this Gamma approximation and the Gamma prior on f *is* exactly integrable. + + The approximation breaks down if the allele fractions are not small (since then the support of the Gamma for f > 1 + becaomes significant), so we should only use the Gamma prior version to model artifact allele fractions. + + In addition to 'beta' and 'gamma' modes, there is also the 'none' mode which has no overdispersion in the individual components. + That is, each component is a plain binomial, though of course by virtue of being a mixture the distribution as a whole is overdispersed. + """ + + def __init__(self, num_components: int, max_mean: float = 1, mode: str = 'beta'): + super(OverdispersedBinomialMixture, self).__init__() + self.mode = mode + self.K = num_components + self.V = len(Variation) + self.max_mean = max_mean + + # parameters for each component and variant type: + self.weights_pre_softmax_vk = torch.nn.Parameter(torch.ones(self.V, self.K)) + self.mean_pre_sigmoid_vk = torch.nn.Parameter(torch.randn(self.V, self.K)) + self.concentration_pre_sigmoid_vk = torch.nn.Parameter(torch.randn(self.V, self.K)) + self.max_concentration = torch.nn.Parameter(torch.tensor(50.0)) + + ''' + here x is a 2D tensor, 1st dimension batch, 2nd dimension being features that determine which Beta mixture to use + n and k are 1D tensors, the only dimension being batch. + ''' + def forward(self, types_b, n_b, k_b): + types_idx = types_b.long() + log_weights_bk = log_softmax(self.weights_pre_softmax_vk[types_idx, :], dim=-1) + + # we make them 2D, with 1st dim batch, to match alpha and beta. A single column is OK because the single value of + # n/k are broadcast over all mixture components + n_bk = n_b[:, None] + k_bk = k_b[:, None] + + # 2D tensors -- 1st dim batch, 2nd dim mixture component + mean_bk = self.max_mean * torch.sigmoid(self.mean_pre_sigmoid_vk[types_idx, :]) + concentration_bk = self.get_concentration(types_b) + + if self.mode == 'beta': + alpha_bk = mean_bk * concentration_bk + beta_bk = (1 - mean_bk) * concentration_bk + log_likelihoods_bk = beta_binomial(n_bk, k_bk, alpha_bk, beta_bk) + elif self.mode == 'gamma': + alpha_bk = mean_bk * concentration_bk + beta_bk = concentration_bk + log_likelihoods_bk = gamma_binomial(n_bk, k_bk, alpha_bk, beta_bk) + elif self.mode == 'none': + # each mean is the center of a binomial + log_likelihoods_bk = binomial(n_bk, k_bk, mean_bk) + else: + raise Exception("we don't have that kind of mode!") + + log_weighted_likelihoods_bk = log_weights_bk + log_likelihoods_bk + + # yields one number per batch, squeezed into 1D output tensor + return logsumexp(log_weighted_likelihoods_bk, dim=-1, keepdim=False) + + def get_concentration(self, types_b): + return self.max_concentration * torch.sigmoid(self.concentration_pre_sigmoid_vk[types_b.long(),:]) + + # given 1D input tensor, return 1D tensors of component alphas and betas + def component_shapes(self, var_type: int): + means_k = self.max_mean * torch.sigmoid(self.mean_pre_sigmoid_vk[var_type]) + concentrations_k = self.max_concentration * torch.sigmoid(self.concentration_pre_sigmoid_vk[var_type]) + alphas_k = means_k * concentrations_k + betas_k = (1 - means_k) * concentrations_k if self.mode == 'beta' else concentrations_k + return alphas_k, betas_k + + def component_weights(self, var_type: int): + return softmax(self.weights_pre_softmax_vk[var_type], dim=-1) + + # given variant type, return the moments E[x], E[ln(x)], and E[x ln(x)] of the underlying beta mixture + def moments_of_underlying_beta_mixture(self, var_type: int): + assert self.mode == 'beta' + alphas, betas = self.component_shapes(var_type) + weights = self.component_weights(var_type) + + # E[x] + component_means = alphas / (alphas + betas) + mixture_mean = torch.sum(weights * component_means) + + # E[ln(x)] + component_log_means = torch.digamma(alphas) - torch.digamma( + alphas + betas) # digamma broadcasts to make 1D tensor + mixture_log_mean = torch.sum(weights * component_log_means) + + # E[x ln(x)] + component_log_linear_means = component_means * (torch.digamma(alphas + 1) - torch.digamma(alphas + betas + 1)) + mixture_log_linear_mean = torch.sum(weights * component_log_linear_means) + + return mixture_mean, mixture_log_mean, mixture_log_linear_mean + + ''' + here x is a 2D tensor, 1st dimension batch, 2nd dimension being features that determine which Beta mixture to use + n is a 1D tensor, the only dimension being batch, and we sample a 1D tensor of k's + ''' + def sample(self, types_b, n): + # compute weights and select one mixture component from the corresponding multinomial for each datum / row + weights = softmax(self.weights_pre_softmax_vk[types_b, :], dim=-1) + component_indices = torch.multinomial(weights, num_samples=1, replacement=True) # 2D tensor with one column + + # get 1D tensors of one selected alpha and beta shape parameter per datum / row, then sample a fraction from each + # It may be very wasteful computing everything and only using one component, but this is just for unit testing + means = self.max_mean * torch.sigmoid(self.mean_pre_sigmoid_vk[types_b, :].detach()).gather(dim=1, index=component_indices).squeeze() + concentrations = self.get_concentration(types_b).detach().gather(dim=1, index=component_indices).squeeze() + alphas = means * concentrations + betas = (1 - means) * concentrations if self.mode == 'beta' else concentrations + dist = torch.distributions.beta.Beta(alphas, betas) if self.mode == 'beta' else torch.distributions.gamma.Gamma(alphas, betas) + fractions = dist.sample() # 1D tensor + + # recall, n and fractions are 1D tensors; result is also 1D tensor, one "success" count per datum + return torch.distributions.binomial.Binomial(total_count=n, probs=fractions).sample() + + def fit(self, num_epochs, types_b, depths_1d_tensor, alt_counts_1d_tensor, batch_size=64): + optimizer = torch.optim.Adam(self.parameters()) + num_batches = math.ceil(len(alt_counts_1d_tensor) / batch_size) + + for epoch in range(num_epochs): + for batch in range(num_batches): + batch_start = batch * batch_size + batch_end = min(batch_start + batch_size, len(alt_counts_1d_tensor)) + batch_slice = slice(batch_start, batch_end) + loss = -torch.mean(self.forward(types_b[batch_slice], depths_1d_tensor[batch_slice], + alt_counts_1d_tensor[batch_slice])) + utils.backpropagate(optimizer, loss) + + ''' + get raw data for a spectrum plot of probability density vs allele fraction. + here x is a 1D tensor, a single datum/row of the 2D tensors as above + ''' + def spectrum_density_vs_fraction(self, variant_type: Variation, depth: int): + # device = self.mean_pre_sigmoid_vk.device + fractions = torch.arange(0.01, 0.99, 0.001) # 1D tensor on CPU + + log_weights_k = log_softmax(self.weights_pre_softmax_vk[variant_type].detach(), dim=-1).cpu() + means_k = self.max_mean * torch.sigmoid(self.mean_pre_sigmoid_vk[variant_type].detach()).cpu() + + # now we're on CPU + if self.mode == 'none': + # this is copied from the beta case below -- basically we smear each delta function / discrete binomial + # into a narrow Gaussian + dist = torch.distributions.normal.Normal(means_k, 0.01 * torch.ones_like(means_k)) + densities = exp(torch.logsumexp(log_weights_k + dist.log_prob(fractions.unsqueeze(dim=1)), dim=1, + keepdim=False)) # 1D tensor + return fractions, densities + else: + concentrations_k = self.max_concentration.cpu() * torch.sigmoid(self.concentration_pre_sigmoid_vk[variant_type]).detach().cpu() + alphas_k = means_k * concentrations_k + betas_k = (1 - means_k) * concentrations_k if self.mode == 'beta' else concentrations_k + + # since f.unsqueeze(dim=1) is 2D column vector, log_prob produces 2D tensor where row index is f and column index is mixture component + # adding the single-row 2D tensor log_weights broadcasts to each row / value of f + # then we apply log_sum_exp, dim= 1, to sum over components and get a log_density for each f + dist = torch.distributions.beta.Beta(alphas_k, betas_k) if self.mode == 'beta' else torch.distributions.gamma.Gamma(alphas_k, betas_k) + densities = exp(torch.logsumexp(log_weights_k + dist.log_prob(fractions.unsqueeze(dim=1)), dim=1, keepdim=False)) # 1D tensor + + return fractions, densities + + ''' + here x is a 1D tensor, a single datum/row of the 2D tensors as above + ''' + def plot_spectrum(self, x, title, depth: int): + fractions, densities = self.spectrum_density_vs_fraction(x, depth) + return simple_plot([(fractions.numpy(), densities.numpy(), " ")], "AF", "density", title) + + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/posterior_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/posterior_model.py new file mode 100644 index 00000000000..39500f3f146 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/posterior_model.py @@ -0,0 +1,367 @@ +from collections import defaultdict +from itertools import chain +from math import ceil + +import torch +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter +from tqdm.autonotebook import trange, tqdm + +from permutect import utils +from permutect.architecture.artifact_spectra import ArtifactSpectra +from permutect.architecture.overdispersed_binomial_mixture import OverdispersedBinomialMixture +from permutect.architecture.normal_seq_error_spectrum import NormalSeqErrorSpectrum +from permutect.architecture.somatic_spectrum import SomaticSpectrum +from permutect.data.base_datum import DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT +from permutect.data.posterior import PosteriorBatch +from permutect.metrics import plotting +from permutect.utils import Variation, Call +from permutect.metrics.evaluation_metrics import MAX_COUNT, NUM_COUNT_BINS, multiple_of_three_bin_index, multiple_of_three_bin_index_to_count + + +# TODO: write unit test asserting that this comes out to zero when counts are zero +# given germline, the probability of these particular reads being alt +def germline_log_likelihood(afs, mafs, alt_counts, depths, het_beta=None): + hom_alpha, hom_beta = torch.tensor([98.0], device=depths.device), torch.tensor([2.0], device=depths.device) + het_alpha, het_beta_to_use = (None, None) if het_beta is None else (torch.tensor([het_beta], device=depths.device), torch.tensor([het_beta], device=depths.device)) + het_probs = 2 * afs * (1 - afs) + hom_probs = afs * afs + het_proportion = het_probs / (het_probs + hom_probs) + hom_proportion = 1 - het_proportion + + log_mafs = torch.log(mafs) + log_1m_mafs = torch.log(1 - mafs) + log_half_het_prop = torch.log(het_proportion / 2) + + ref_counts = depths - alt_counts + + combinatorial_term = torch.lgamma(depths + 1) - torch.lgamma(alt_counts + 1) - torch.lgamma(ref_counts + 1) + # the following should both be 1D tensors of length batch size + alt_minor_binomial = combinatorial_term + alt_counts * log_mafs + ref_counts * log_1m_mafs + alt_major_binomial = combinatorial_term + ref_counts * log_mafs + alt_counts * log_1m_mafs + alt_minor_ll = log_half_het_prop + (alt_minor_binomial if het_beta is None else utils.beta_binomial(depths, alt_counts, het_alpha, het_beta_to_use)) + alt_major_ll = log_half_het_prop + (alt_major_binomial if het_beta is None else utils.beta_binomial(depths, alt_counts, het_alpha, het_beta_to_use)) + hom_ll = torch.log(hom_proportion) + utils.beta_binomial(depths, alt_counts, hom_alpha, hom_beta) + + return torch.logsumexp(torch.vstack((alt_minor_ll, alt_major_ll, hom_ll)), dim=0) + + +# TODO: max_mean is hard-coded magic constant!! +def initialize_normal_artifact_spectra(): + return OverdispersedBinomialMixture(num_components=1, max_mean=0.1, mode='beta') + + +# this works for ArtifactSpectra and OverdispersedBinomialMixture +def plot_artifact_spectra(artifact_spectra, depth: int = None): + # plot AF spectra in two-column grid with as many rows as needed + art_spectra_fig, art_spectra_axs = plt.subplots(ceil(len(Variation) / 2), 2, sharex='all', sharey='all') + for variant_type in Variation: + n = variant_type + row, col = int(n / 2), n % 2 + frac, dens = artifact_spectra.spectrum_density_vs_fraction(variant_type, depth) + art_spectra_axs[row, col].plot(frac.detach().numpy(), dens.detach().numpy(), label=variant_type.name) + art_spectra_axs[row, col].set_title(variant_type.name + " artifact AF spectrum") + for ax in art_spectra_fig.get_axes(): + ax.label_outer() + return art_spectra_fig, art_spectra_axs + + +class PosteriorModel(torch.nn.Module): + """ + + """ + def __init__(self, variant_log_prior: float, artifact_log_prior: float, num_base_features: int, no_germline_mode: bool = False, device=utils.gpu_if_available(), het_beta: float = None): + super(PosteriorModel, self).__init__() + + self._device = device + self._dtype = DEFAULT_GPU_FLOAT if device != torch.device("cpu") else DEFAULT_CPU_FLOAT + self.no_germline_mode = no_germline_mode + self.num_base_features = num_base_features + self.het_beta = het_beta + + # TODO introduce parameters class so that num_components is not hard-coded + self.somatic_spectrum = SomaticSpectrum(num_components=5) + + # artifact spectra for each variant type. Variant type encoded as one-hot input vector. + self.artifact_spectra = ArtifactSpectra(num_components=2) + + # normal sequencing error spectra for each variant type. + self.normal_seq_error_spectra = torch.nn.ModuleList([NormalSeqErrorSpectrum(num_samples=50, max_mean=0.001) for _ in Variation]) + + self.normal_artifact_spectra = initialize_normal_artifact_spectra() + + # pre-softmax priors of different call types [log P(variant), log P(artifact), log P(seq error)] for each variant type + self._unnormalized_priors_vc = torch.nn.Parameter(torch.ones(len(Variation), len(Call))) + with torch.no_grad(): + self._unnormalized_priors_vc[:, Call.SOMATIC] = variant_log_prior + self._unnormalized_priors_vc[:, Call.ARTIFACT] = artifact_log_prior + self._unnormalized_priors_vc[:, Call.SEQ_ERROR] = 0 + self._unnormalized_priors_vc[:, Call.GERMLINE] = -9999 if self.no_germline_mode else 0 + self._unnormalized_priors_vc[:, Call.NORMAL_ARTIFACT] = artifact_log_prior + + self.to(device=self._device, dtype=self._dtype) + + def make_unnormalized_priors(self, variant_types_b: torch.IntTensor, allele_frequencies_1d: torch.Tensor) -> torch.Tensor: + result_bc = self._unnormalized_priors_vc[variant_types_b.long(), :].to(device=self._device, dtype=self._dtype) + result_bc[:, Call.SEQ_ERROR] = 0 + result_bc[:, Call.GERMLINE] = -9999 if self.no_germline_mode else torch.log(1 - torch.square(1 - allele_frequencies_1d)) # 1 minus hom ref probability + return result_bc # batch size x len(CallType) + + def posterior_probabilities(self, batch: PosteriorBatch) -> torch.Tensor: + """ + :param batch: + :return: non-log probabilities as a 2D tensor, 1st index is batch, 2nd is variant/artifact/seq error + """ + return torch.nn.functional.softmax(self.log_relative_posteriors(batch), dim=1) + + def error_probabilities(self, batch: PosteriorBatch, germline_mode: bool = False) -> torch.Tensor: + """ + :param germline_mode: if True, germline classification is not considered an error mode + :param batch: + :return: non-log error probabilities as a 1D tensor with length batch size + """ + assert not (germline_mode and self.no_germline_mode), "germline mode and no-germline mode are incompatible" + return 1 - self.posterior_probabilities(batch)[:, Call.GERMLINE if germline_mode else Call.SOMATIC] # 0th column is variant + + def log_posterior_and_ingredients(self, batch: PosteriorBatch) -> torch.Tensor: + """ + :param batch: + :batch.seq_error_log_likelihoods() is the probability that these *particular* reads exhibit the alt allele given a + sequencing error ie an error explainable in terms of base qualities. For example if we have two alt reads with error + probability of 0.1 and 0.2, and two ref reads with error probabilities 0.05 and 0.06 this quantity would be + log(0.1*0.2*0.95*0.94). This is an annotation emitted by the GATK and by the time it reaches here is a 1D tensor + of length batch_size. + :return: + """ + variant_types = batch.get_variant_types().to(device=self._device, dtype=self._dtype) + + # All log likelihood/relative posterior tensors below have shape batch.size() x len(CallType) + # spectra tensors contain the likelihood that these *particular* reads (that is, not just the read count) are alt + # normal log likelihoods contain everything going on in the matched normal sample + # note that the call to make_unnormalized_priors ensures that no_germline_mode works + log_priors = torch.nn.functional.log_softmax(self.make_unnormalized_priors(variant_types, batch.get_allele_frequencies()), dim=1) + + # defined as log [ int_0^1 Binom(alt count | depth, f) df ], including the combinatorial N choose N_alt factor + depths, alt_counts = batch.get_depths(), batch.get_alt_counts() + normal_depths, normal_alt_counts = batch.get_normal_depths(), batch.get_normal_alt_counts() + flat_prior_spectra_log_likelihoods = -torch.log(depths + 1) + somatic_spectrum_log_likelihoods = self.somatic_spectrum.forward(depths, alt_counts) + tumor_artifact_spectrum_log_likelihood = self.artifact_spectra.forward(batch.get_variant_types(), depths, alt_counts) + spectra_log_likelihoods = torch.zeros_like(log_priors, device=self._device, dtype=self._dtype) + + # essentially, this corrects the TLOD from M2, computed with a flat prior, to account for the precises somatic spectrum + spectra_log_likelihoods[:, Call.SOMATIC] = somatic_spectrum_log_likelihoods - flat_prior_spectra_log_likelihoods + spectra_log_likelihoods[:, Call.ARTIFACT] = tumor_artifact_spectrum_log_likelihood - flat_prior_spectra_log_likelihoods + spectra_log_likelihoods[:, Call.NORMAL_ARTIFACT] = tumor_artifact_spectrum_log_likelihood - flat_prior_spectra_log_likelihoods # yup, it's the same spectrum + spectra_log_likelihoods[:, Call.SEQ_ERROR] = -batch.get_tlods_from_m2() + # spectra_log_likelihoods[:, Call.GERMLINE] is computed below + + normal_log_likelihoods = torch.zeros_like(log_priors) + normal_seq_error_log_likelihoods = torch.zeros_like(alt_counts).float() + + for var_index, _ in enumerate(Variation): + mask = (variant_types == var_index) + log_likelihoods_for_this_type = self.normal_seq_error_spectra[var_index].forward(normal_alt_counts, batch.get_normal_ref_counts()) + normal_seq_error_log_likelihoods += mask * log_likelihoods_for_this_type + + normal_log_likelihoods[:, Call.SOMATIC] = normal_seq_error_log_likelihoods + normal_log_likelihoods[:, Call.ARTIFACT] = normal_seq_error_log_likelihoods + normal_log_likelihoods[:, Call.SEQ_ERROR] = normal_seq_error_log_likelihoods + + no_alt_in_normal_mask = normal_alt_counts < 1 + normal_log_likelihoods[:, Call.NORMAL_ARTIFACT] = -9999 * no_alt_in_normal_mask + \ + torch.logical_not(no_alt_in_normal_mask) * self.normal_artifact_spectra.forward(variant_types, normal_depths, normal_alt_counts) + + afs = batch.get_allele_frequencies() + spectra_log_likelihoods[:, Call.GERMLINE] = germline_log_likelihood(afs, batch.get_mafs(), alt_counts, depths, self.het_beta) - flat_prior_spectra_log_likelihoods + + # it is correct not to subtract the flat prior likelihood from the normal term because this is an absolute likelihood, not + # relative to seq error as the M2 TLOD is defined + normal_log_likelihoods[:, Call.GERMLINE] = germline_log_likelihood(afs, batch.get_normal_mafs(), normal_alt_counts, normal_depths, self.het_beta) + + log_posteriors = log_priors + spectra_log_likelihoods + normal_log_likelihoods + log_posteriors[:, Call.ARTIFACT] += batch.get_artifact_logits() + + log_posteriors[:, Call.NORMAL_ARTIFACT] += batch.get_artifact_logits() + + return log_priors, spectra_log_likelihoods, normal_log_likelihoods, log_posteriors + + def log_relative_posteriors(self, batch: PosteriorBatch) -> torch.Tensor: + _, _, _, log_posteriors = self.log_posterior_and_ingredients(batch) + return log_posteriors + + def learn_priors_and_spectra(self, posterior_loader, num_iterations, ignored_to_non_ignored_ratio: float, + summary_writer: SummaryWriter = None, learning_rate: float = 0.001): + """ + :param summary_writer: + :param num_iterations: + :param posterior_loader: + :param ignored_to_non_ignored_ratio: ratio of sites in which no evidence of variation was found to sites in which + sufficient evidence was found to emit test data. Without this parameter (i.e. if it were set to zero) we would + underestimate the frequency of sequencing error, hence overestimate the prior probability of variation. + :param artifact_spectra_state_dict: (possibly None) if given, pretrained parameters of self.artifact_spectra + from train_model.py. In this case we make sure to freeze this part of the model + :param artifact_log_priors: (possibly None) 1D tensor with length len(utils.Variation) containing log prior probabilities + of artifacts for each variation type, from train_model.py. If given, freeze these parameters. + :return: + """ + spectra_and_prior_params = chain(self.somatic_spectrum.parameters(), self.artifact_spectra.parameters(), + [self._unnormalized_priors_vc], self.normal_seq_error_spectra.parameters(), + self.normal_artifact_spectra.parameters()) + optimizer = torch.optim.Adam(spectra_and_prior_params, lr=learning_rate) + + for epoch in trange(1, num_iterations + 1, desc="AF spectra epoch"): + epoch_loss = utils.StreamingAverage() + + # store posteriors as a list (to be stacked at the end of the epoch) for an M step + # 'l' for loader, 'b' for batch, 'c' for call type + posteriors_lbc = [] + alt_counts_lb = [] + depths_lb = [] + types_lb = [] + + pbar = tqdm(enumerate(posterior_loader), mininterval=10) + batch_cpu: PosteriorBatch + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(self._device, self._dtype, non_blocking=self._device.type == 'cuda') + relative_posteriors = self.log_relative_posteriors(batch) + log_evidence = torch.logsumexp(relative_posteriors, dim=1) + + posteriors_lbc.append(torch.softmax(relative_posteriors, dim=-1).detach()) + alt_counts_lb.append(batch.get_alt_counts().detach()) + depths_lb.append(batch.get_depths().detach()) + types_lb.append(batch.get_variant_types().detach()) + + confidence_mask = torch.abs(batch.get_artifact_logits()) > 3.0 + #loss = -torch.mean(confidence_mask * log_evidence) + loss = - torch.sum(confidence_mask * log_evidence) / (torch.sum(confidence_mask) + 0.000001) + + # note that we don't multiply by batch size because we take the mean of log evidence above + # however, we must sum over variant types since each ignored site is simultaneously a missing non-SNV, + # a missing non-INSERTION etc + # we use a germline allele frequency of 0.001 for the missing sites but it doesn't really matter + for var_type_idx, variant_type in enumerate(Variation): + log_priors = torch.nn.functional.log_softmax(self.make_unnormalized_priors(torch.LongTensor([var_type_idx]).to(device=self._device, dtype=self._dtype), torch.tensor([0.001], device=self._device)), dim=1) + log_seq_error_prior = log_priors.squeeze()[Call.SEQ_ERROR] + missing_loss = -ignored_to_non_ignored_ratio * log_seq_error_prior + loss += missing_loss + + utils.backpropagate(optimizer, loss) + + epoch_loss.record_sum(batch.size() * loss.detach().item(), batch.size()) + # iteration over posterior dataloader finished + + # 'n' denotes index of data within entire Posterior Dataset + posteriors_nc = torch.vstack(posteriors_lbc) + alt_counts_n = torch.hstack(alt_counts_lb) + depths_n = torch.hstack(depths_lb) + types_n = torch.hstack(types_lb) + + self.update_priors_m_step(posteriors_nc, types_n, ignored_to_non_ignored_ratio) + self.somatic_spectrum.update_m_step(posteriors_nc[:, Call.SOMATIC], alt_counts_n, depths_n) + + if summary_writer is not None: + summary_writer.add_scalar("spectrum negative log evidence", epoch_loss.get(), epoch) + + for variant_index, variant_type in enumerate(Variation): + mean = self.normal_seq_error_spectra[variant_index].get_mean() + summary_writer.add_scalar("normal seq error mean fraction for " + variant_type.name, mean, epoch) + + for depth in [10, 20, 30, 50, 100]: + art_spectra_fig, art_spectra_axs = plot_artifact_spectra(self.artifact_spectra, depth) + summary_writer.add_figure("Artifact AF Spectra at depth = " + str(depth), art_spectra_fig, epoch) + + #normal_seq_error_spectra_fig, normal_seq_error_spectra_axs = plot_artifact_spectra(self.normal_seq_error_spectra) + #summary_writer.add_figure("Normal Seq Error AF Spectra", normal_seq_error_spectra_fig, epoch) + + normal_artifact_spectra_fig, normal_artifact_spectra_axs = plot_artifact_spectra(self.normal_artifact_spectra) + summary_writer.add_figure("Normal Artifact AF Spectra", normal_artifact_spectra_fig, epoch) + + var_spectra_fig, var_spectra_axs = plt.subplots() + frac, dens = self.somatic_spectrum.spectrum_density_vs_fraction() + var_spectra_axs.plot(frac.detach().numpy(), dens.detach().numpy(), label="spectrum") + var_spectra_axs.set_title("Variant AF Spectrum") + summary_writer.add_figure("Variant AF Spectra", var_spectra_fig, epoch) + + # bar plot of log priors -- data is indexed by call type name, and x ticks are variant types + log_prior_bar_plot_data = defaultdict(list) + for var_type_idx, variant_type in enumerate(Variation): + log_priors = torch.nn.functional.log_softmax(self.make_unnormalized_priors(torch.LongTensor([var_type_idx]).to(device=self._device, dtype=self._dtype), torch.Tensor([0.001])), dim=-1) + log_priors_cpu = log_priors.squeeze().detach().cpu() + for call_type in (Call.SOMATIC, Call.ARTIFACT, Call.NORMAL_ARTIFACT): + log_prior_bar_plot_data[call_type.name].append(log_priors_cpu[call_type]) + + prior_fig, prior_ax = plotting.grouped_bar_plot(log_prior_bar_plot_data, [v_type.name for v_type in Variation], "log priors") + summary_writer.add_figure("log priors", prior_fig, epoch) + + # normal artifact joint tumor-normal spectra + # na_fig, na_axes = plt.subplots(1, len(Variation), sharex='all', sharey='all', squeeze=False) + # for variant_index, variant_type in enumerate(Variation): + # self.normal_artifact_spectra[variant_index].density_plot_on_axis(na_axes[0, variant_index]) + # plotting.tidy_subplots(na_fig, na_axes, x_label="tumor fraction", y_label="normal fraction", + # row_labels=[""], column_labels=[var_type.name for var_type in Variation]) + # summary_writer.add_figure("normal artifact spectra", na_fig, epoch) + + # map of Variant type to probability threshold that maximizes F1 score + # loader is a Dataloader whose collate_fn is the PosteriorBatch constructor + def calculate_probability_thresholds(self, loader, summary_writer: SummaryWriter = None, germline_mode: bool = False): + self.train(False) + error_probs_by_type = {var_type: [] for var_type in Variation} # includes both artifact and seq errors + + error_probs_by_type_by_cnt = {var_type: [[] for _ in range(NUM_COUNT_BINS)] for var_type in Variation} + + pbar = tqdm(enumerate(loader), mininterval=10) + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(self._device, self._dtype, non_blocking=self._device.type == 'cuda') + alt_counts = batch_cpu.get_alt_counts().tolist() + # 0th column is true variant, subtract it from 1 to get error prob + error_probs = self.error_probabilities(batch, germline_mode).cpu().tolist() + + for var_type, alt_count, error_prob in zip(batch_cpu.get_variant_types().tolist(), alt_counts, error_probs): + error_probs_by_type[var_type].append(error_prob) + error_probs_by_type_by_cnt[var_type][multiple_of_three_bin_index(min(alt_count, MAX_COUNT))].append(error_prob) + + thresholds_by_type = {} + roc_fig, roc_axes = plt.subplots(1, len(Variation), sharex='all', sharey='all', squeeze=False) + roc_by_cnt_fig, roc_by_cnt_axes = plt.subplots(1, len(Variation), sharex='all', sharey='all', squeeze=False, figsize=(10, 6), dpi=100) + for var_type in Variation: + # plot all count ROC curves for this variant type + count_bin_labels = [str(multiple_of_three_bin_index_to_count(count_bin)) for count_bin in range(NUM_COUNT_BINS)] + _ = plotting.plot_theoretical_roc_on_axis(error_probs_by_type_by_cnt[var_type], count_bin_labels, roc_by_cnt_axes[0, var_type]) + best_threshold = plotting.plot_theoretical_roc_on_axis([error_probs_by_type[var_type]], [""], roc_axes[0, var_type])[0][0] + + # TODO: the theoretical ROC might need to return the best threshold for this + thresholds_by_type[var_type] = best_threshold + + variation_types = [var_type.name for var_type in Variation] + plotting.tidy_subplots(roc_by_cnt_fig, roc_by_cnt_axes, x_label="sensitivity", y_label="precision", + row_labels=[""], column_labels=variation_types) + plotting.tidy_subplots(roc_fig, roc_axes, x_label="sensitivity", y_label="precision", + row_labels=[""], column_labels=variation_types) + if summary_writer is not None: + summary_writer.add_figure("theoretical ROC by variant type ", roc_fig) + summary_writer.add_figure("theoretical ROC by variant type and alt count ", roc_by_cnt_fig) + + return thresholds_by_type + + def update_priors_m_step(self, posteriors_nc, types_n, ignored_to_non_ignored_ratio): + # update the priors in an EM-style M step. We'll need the counts of each call type vs variant type + total_nonignored = torch.sum(posteriors_nc).item() + total_ignored = ignored_to_non_ignored_ratio * total_nonignored + overall_total = total_ignored + total_nonignored + + with torch.no_grad(): + for c, call_type in enumerate(Call): + if call_type == Call.SEQ_ERROR or call_type == Call.GERMLINE: + continue + posteriors_n = posteriors_nc[:, c] + + for t, var_type in enumerate(Variation): + var_type_mask = (types_n == t) + total_for_this_call_and_var_type = torch.sum(posteriors_n * var_type_mask) + + self._unnormalized_priors_vc[t, c] = torch.log( + total_for_this_call_and_var_type / (total_for_this_call_and_var_type + overall_total)).item() + + self._unnormalized_priors_vc[:, Call.SEQ_ERROR] = 0 + self._unnormalized_priors_vc[:, Call.GERMLINE] = -9999 if self.no_germline_mode else 0 diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/architecture/somatic_spectrum.py b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/somatic_spectrum.py new file mode 100644 index 00000000000..ddd2dd507ea --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/architecture/somatic_spectrum.py @@ -0,0 +1,149 @@ +import math + +import torch +from permutect import utils +from torch import nn +from torch.nn.functional import log_softmax + +from permutect.metrics.plotting import simple_plot +from permutect.utils import beta_binomial, binomial + +# exclude obvious germline, artifact, sequencing error etc from M step for speed +MIN_POSTERIOR_FOR_M_STEP = 0.2 + + +class SomaticSpectrum(nn.Module): + """ + This model takes in 1D tensor (batch size, ) alt counts and depths and computes the log likelihoods + log P(alt count | depth, spectrum parameters). + + The probability P(alt count | depth) is a K-component mixture model where K-1 components are simple binomials + P_k(a|d) = Binom(a|d, f_k) = (d C a) f_k^a (1-f_k)^(d-a), where f_k is the allele fraction associated with component + k and the Kth component is a background beta binomial P_K(a|d) = integral{Beta(f|alpha, beta) * Binom(a|d, f) df}. + + This integral is exact and is implemented in utils.beta_binomial() + + We compute the binomial and beta binomial log likelihoods, then add in log space via logsumexp to get the overall + mixture log likelihood. + """ + + def __init__(self, num_components: int): + super(SomaticSpectrum, self).__init__() + self.K = num_components + + # initialize equal weights for each binomial component and larger weight for beta binomial background (last component) + weights_pre_softmax = torch.ones(self.K) + weights_pre_softmax[-1] = 3 + + self.weights_pre_softmax_k = torch.nn.Parameter(weights_pre_softmax) + + # initialize evenly spaced pre-sigmoid from -2 to 2 + self.f_pre_sigmoid_k = torch.nn.Parameter((4 * torch.arange(self.K - 1) / (self.K - 1)) - 2) + + # the alpha, beta shape parameters are exponentiated in the forward pass to ensure positive values + self.alpha_pre_exp = torch.nn.Parameter(torch.tensor(1.0)) + self.beta_pre_exp = torch.nn.Parameter(torch.tensor(1.0)) + + ''' + here alt counts and depths are 1D (batch size, ) tensors + ''' + def forward(self, depths_b, alt_counts_b): + weighted_likelihoods_bk = self.weighted_likelihoods_by_cluster(depths_b, alt_counts_b) + result_b = torch.logsumexp(weighted_likelihoods_bk, dim=1, keepdim=False) + return result_b + + + ''' + here alt counts and depths are 1D (batch size, ) tensors + ''' + def weighted_likelihoods_by_cluster(self, depths_b, alt_counts_b): + batch_size = len(alt_counts_b) + + f_k = torch.sigmoid(self.f_pre_sigmoid_k) + f_bk = f_k.expand(batch_size, -1) + alt_counts_bk = torch.unsqueeze(alt_counts_b, dim=1).expand(-1, self.K - 1) + depths_bk = torch.unsqueeze(depths_b, dim=1).expand(-1, self.K - 1) + binomial_likelihoods_bk = binomial(depths_bk, alt_counts_bk, f_bk) + + alpha = torch.exp(self.alpha_pre_exp) + beta = torch.exp(self.beta_pre_exp) + alpha_b = alpha.expand(batch_size) + beta_b = beta.expand(batch_size) + + beta_binomial_likelihoods_b = beta_binomial(depths_b, alt_counts_b, alpha_b, beta_b) + beta_binomial_likelihoods_bk = torch.unsqueeze(beta_binomial_likelihoods_b, dim=1) + + likelihoods_bk = torch.hstack((binomial_likelihoods_bk, beta_binomial_likelihoods_bk)) + + log_weights_k = log_softmax(self.weights_pre_softmax_k, dim=-1) # these weights are normalized + log_weights_bk = log_weights_k.expand(batch_size, -1) + weighted_likelihoods_bk = log_weights_bk + likelihoods_bk + + return weighted_likelihoods_bk + + # posteriors: responsibilities that each object is somatic + def update_m_step(self, posteriors_n, alt_counts_n, depths_n): + possible_somatic_indices = posteriors_n > MIN_POSTERIOR_FOR_M_STEP + somatic_posteriors_n = posteriors_n[possible_somatic_indices] + somatic_alt_counts_n = alt_counts_n[possible_somatic_indices] + somatic_depths_n = depths_n[possible_somatic_indices] + + # TODO: make sure this all fits on GPU + # TODO: maybe split it up into batches? + weighted_likelihoods_nk = self.weighted_likelihoods_by_cluster(somatic_depths_n, somatic_alt_counts_n) + cluster_posteriors_nk = somatic_posteriors_n[:, None] * torch.softmax(weighted_likelihoods_nk, dim=-1) + cluster_totals_k = torch.sum(cluster_posteriors_nk, dim=0) + + with torch.no_grad(): + self.weights_pre_softmax_k.copy_(torch.log(cluster_totals_k + 0.00001)) + + # update the binomial clusters -- we exclude the last cluster, which is beta binomial + for k in range(self.K - 1): + weights = cluster_posteriors_nk[:, k] + f = torch.sum((weights * somatic_alt_counts_n)) / torch.sum((0.00001 + weights * somatic_depths_n)) + + self.f_pre_sigmoid_k[k] = torch.log(f / (1-f)) + + def fit(self, num_epochs, depths_1d_tensor, alt_counts_1d_tensor, batch_size=64): + optimizer = torch.optim.Adam(self.parameters()) + num_batches = math.ceil(len(alt_counts_1d_tensor) / batch_size) + + for epoch in range(num_epochs): + for batch in range(num_batches): + batch_start = batch * batch_size + batch_end = min(batch_start + batch_size, len(alt_counts_1d_tensor)) + batch_slice = slice(batch_start, batch_end) + loss = -torch.mean(self.forward(depths_1d_tensor[batch_slice], alt_counts_1d_tensor[batch_slice])) + utils.backpropagate(optimizer, loss) + + ''' + get raw data for a spectrum plot of probability density vs allele fraction + ''' + def spectrum_density_vs_fraction(self): + fractions_f = torch.arange(0.01, 0.99, 0.001) # 1D tensor + + f_k = torch.sigmoid(self.f_pre_sigmoid_k).cpu() + + # smear each binomial f into a narrow Gaussian for plotting + gauss_k = torch.distributions.normal.Normal(f_k, 0.01 * torch.ones_like(f_k)) + log_gauss_fk = gauss_k.log_prob(fractions_f.unsqueeze(dim=1)) + + alpha = torch.exp(self.alpha_pre_exp).cpu() + beta = torch.exp(self.beta_pre_exp).cpu() + + beta = torch.distributions.beta.Beta(alpha, beta) + log_beta_fk = beta.log_prob(fractions_f.unsqueeze(dim=1)) + + log_densities_fk = torch.hstack((log_gauss_fk, log_beta_fk)) + + log_weights_k = log_softmax(self.weights_pre_softmax_k, dim=-1).cpu() # these weights are normalized + log_weights_fk = log_weights_k.expand(len(fractions_f), -1) + + log_weighted_densities_fk = log_weights_fk + log_densities_fk + densities_f = torch.exp(torch.logsumexp(log_weighted_densities_fk, dim=1, keepdim=False)) + + return fractions_f, densities_f + + def plot_spectrum(self, title): + fractions, densities = self.spectrum_density_vs_fraction() + return simple_plot([(fractions.numpy(), densities.numpy(), " ")], "AF", "density", title) \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/constants.py b/src/main/python/org/broadinstitute/hellbender/permutect/constants.py new file mode 100644 index 00000000000..d81e848e298 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/constants.py @@ -0,0 +1,66 @@ +STATE_DICT_NAME = 'model_state_dict' +ARTIFACT_LOG_PRIORS_NAME = 'artifact_log_priors' +ARTIFACT_SPECTRA_STATE_DICT_NAME = 'artifact_spectra_state_dict' +HYPERPARAMS_NAME = 'hyperparams' +NUM_READ_FEATURES_NAME = 'num_read_features' +NUM_INFO_FEATURES_NAME = 'num_info_features' +REF_SEQUENCE_LENGTH_NAME = 'ref_sequence_length' +HIDDEN_LAYERS_NAME = 'hidden_layers' +NUM_BASE_FEATURES_NAME = 'num_base_features' +NUM_REF_ALT_FEATURES_NAME = 'num_ref_alt_features' + +SOURCES_NAME = 'sources' +SOURCE_NAME = 'source' + +INPUT_NAME = 'input' +OUTPUT_NAME = 'output' +OUTPUT_DIR_NAME = 'output_dir' + +READ_LAYERS_NAME = 'read_layers' +SELF_ATTENTION_HIDDEN_DIMENSION_NAME = 'self_attention_hidden_dimension' +NUM_SELF_ATTENTION_LAYERS_NAME = 'num_self_attention_layers' + +LEARNING_METHOD_NAME = 'learning_method' + +INFO_LAYERS_NAME = 'info_layers' +AGGREGATION_LAYERS_NAME = 'aggregation_layers' +CALIBRATION_LAYERS_NAME = 'calibration_layers' +REF_SEQ_LAYER_STRINGS_NAME = 'ref_seq_layer_strings' +DROPOUT_P_NAME = 'dropout_p' +LEARNING_RATE_NAME = 'learning_rate' +WEIGHT_DECAY_NAME = 'weight_decay' +BATCH_NORMALIZE_NAME = 'batch_normalize' +LEARN_ARTIFACT_SPECTRA_NAME = 'learn_artifact_spectra' + +TRAINING_DATASETS_NAME = 'training_datasets' +TRAIN_TAR_NAME = 'train_tar' +EVALUATION_TAR_NAME = 'evaluation_tar' +TEST_DATASET_NAME = 'test_dataset' +NORMAL_ARTIFACT_DATASETS_NAME = 'normal_artifact_datasets' +REWEIGHTING_RANGE_NAME = 'reweighting_range' +BATCH_SIZE_NAME = 'batch_size' +CHUNK_SIZE_NAME = 'chunk_size' +NUM_EPOCHS_NAME = 'num_epochs' +NUM_CALIBRATION_EPOCHS_NAME = 'num_calibration_epochs' +INFERENCE_BATCH_SIZE_NAME = 'inference_batch_size' +NUM_WORKERS_NAME = 'num_workers' +NUM_SPECTRUM_ITERATIONS_NAME = 'num_spectrum_iterations' +SPECTRUM_LEARNING_RATE_NAME = 'spectrum_learning_rate' + +DATASET_EDIT_TYPE_NAME = 'dataset_edit' + +TENSORBOARD_DIR_NAME = 'tensorboard_dir' + +INITIAL_LOG_VARIANT_PRIOR_NAME = 'initial_log_variant_prior' +INITIAL_LOG_ARTIFACT_PRIOR_NAME = 'initial_log_artifact_prior' +CONTIGS_TABLE_NAME = 'contigs_table' +GENOMIC_SPAN_NAME = 'genomic_span' +MAF_SEGMENTS_NAME = 'maf_segments' +NORMAL_MAF_SEGMENTS_NAME = 'normal_maf_segments' +GERMLINE_MODE_NAME = 'germline_mode' +NO_GERMLINE_MODE_NAME = 'no_germline_mode' +HET_BETA_NAME = 'het_beta' + +BASE_MODEL_NAME = 'base_model' +M3_MODEL_NAME = 'permutect_model' +PRETRAINED_MODEL_NAME = 'pretrained_model' diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/data/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/data/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/data/artifact_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/data/artifact_dataset.py new file mode 100644 index 00000000000..d3e6df0fc84 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/data/artifact_dataset.py @@ -0,0 +1,109 @@ +import math +import random +from typing import List + +import torch +from tqdm.autonotebook import tqdm +from torch.utils.data import Dataset, DataLoader, Sampler + +from permutect.architecture.base_model import BaseModel +from permutect.data.base_datum import ArtifactDatum, ArtifactBatch +from permutect.data.base_dataset import BaseDataset, chunk + + +# given a ReadSetDataset, apply a BaseModel to get an ArtifactDataset (in RAM, maybe implement memory map later) +# of RepresentationReadSets +class ArtifactDataset(Dataset): + def __init__(self, base_dataset: BaseDataset, + base_model: BaseModel, + folds_to_use: List[int] = None, + base_loader_num_workers=0, + base_loader_batch_size=8192): + self.counts_by_source = base_dataset.counts_by_source + self.totals = base_dataset.totals + self.source_totals = base_dataset.source_totals + self.weights = base_dataset.weights + self.source_weights = base_dataset.source_weights + + self.artifact_data = [] + self.num_folds = base_dataset.num_folds + self.labeled_indices = [[] for _ in range(self.num_folds)] # one list for each fold + self.unlabeled_indices = [[] for _ in range(self.num_folds)] # ditto + self.num_base_features = base_model.output_dimension() + self.num_ref_alt_features = base_model.ref_alt_seq_embedding_dimension() + + index = 0 + + loader = base_dataset.make_data_loader(base_dataset.all_folds() if folds_to_use is None else folds_to_use, + batch_size=base_loader_batch_size, + num_workers=base_loader_num_workers) + print("making artifact dataset from base dataset") + + is_cuda = base_model._device.type == 'cuda' + print(f"Is base model using CUDA? {is_cuda}") + + pbar = tqdm(enumerate(loader), mininterval=60) + for n, base_batch_cpu in pbar: + base_batch = base_batch_cpu.copy_to(base_model._device, non_blocking=is_cuda) + with torch.inference_mode(): + representations, _ = base_model.calculate_representations(base_batch) + + for representation, base_datum in zip(representations.detach().cpu(), base_batch_cpu.original_list()): + artifact_datum = ArtifactDatum(base_datum, representation.detach()) + self.artifact_data.append(artifact_datum) + fold = index % self.num_folds + if artifact_datum.is_labeled(): + self.labeled_indices[fold].append(index) + else: + self.unlabeled_indices[fold].append(index) + index += 1 + + def __len__(self): + return len(self.artifact_data) + + def __getitem__(self, index): + return self.artifact_data[index] + + # it is often convenient to arbitrarily use the last fold for validation + def last_fold_only(self): + return [self.num_folds - 1] # use the last fold for validation + + def all_but_the_last_fold(self): + return list(range(self.num_folds - 1)) + + def all_but_one_fold(self, fold_to_exclude: int): + return list(range(fold_to_exclude)) + list(range(fold_to_exclude + 1, self.num_folds)) + + def all_folds(self): + return list(range(self.num_folds)) + + def make_data_loader(self, folds_to_use: List[int], batch_size: int, pin_memory=False, num_workers: int = 0, labeled_only: bool = False): + sampler = SemiSupervisedArtifactBatchSampler(self, batch_size, folds_to_use, labeled_only) + return DataLoader(dataset=self, batch_sampler=sampler, collate_fn=ArtifactBatch, pin_memory=pin_memory, num_workers=num_workers) + + +# make ArtifactBatches that mix different ref, alt counts, labeled, unlabeled +# with an option to emit only labeled data +class SemiSupervisedArtifactBatchSampler(Sampler): + def __init__(self, dataset: ArtifactDataset, batch_size, folds_to_use: List[int], labeled_only: bool = False): + # combine the index lists of all relevant folds + self.indices_to_use = [] + + for fold in folds_to_use: + self.indices_to_use.extend(dataset.labeled_indices[fold]) + if not labeled_only: + self.indices_to_use.extend(dataset.unlabeled_indices[fold]) + + self.batch_size = batch_size + self.num_batches = math.ceil(len(self.indices_to_use) // self.batch_size) + + def __iter__(self): + random.shuffle(self.indices_to_use) + batches = chunk(self.indices_to_use, self.batch_size) # list of lists of indices -- each sublist is a batch + random.shuffle(batches) + + return iter(batches) + + def __len__(self): + return self.num_batches + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/data/base_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/data/base_dataset.py new file mode 100644 index 00000000000..88d23e34251 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/data/base_dataset.py @@ -0,0 +1,222 @@ +import math +import os +import psutil +import random +import tarfile +import tempfile +from collections import defaultdict +from itertools import chain +from typing import Iterable, List + +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.sampler import Sampler + +from mmap_ninja.ragged import RaggedMmap +from permutect import utils +from permutect.data.base_datum import BaseDatum, BaseBatch, load_list_of_base_data, OneDimensionalData +from permutect.utils import Label, MutableInt + +TENSORS_PER_BASE_DATUM = 2 # 1) 2D reads (ref and alt), 1) 1D concatenated stuff + +# tarfiles on disk take up about 4x as much as the dataset on RAM +TARFILE_TO_RAM_RATIO = 4 + +ALL_COUNTS_SENTINEL = 0 + +WEIGHT_PSEUDOCOUNT = 10 + + +def ratio_with_pseudocount(a, b): + return (a + WEIGHT_PSEUDOCOUNT) / (b + WEIGHT_PSEUDOCOUNT) + + +class BaseDataset(Dataset): + def __init__(self, data_in_ram: Iterable[BaseDatum] = None, data_tarfile=None, num_folds: int = 1): + super(BaseDataset, self).__init__() + assert data_in_ram is not None or data_tarfile is not None, "No data given" + assert data_in_ram is None or data_tarfile is None, "Data given from both RAM and tarfile" + self.num_folds = num_folds + + if data_in_ram is not None: + self._data = data_in_ram + self._memory_map_mode = False + else: + tarfile_size = os.path.getsize(data_tarfile) # in bytes + estimated_data_size_in_ram = tarfile_size // TARFILE_TO_RAM_RATIO + available_memory = psutil.virtual_memory().available + fits_in_ram = estimated_data_size_in_ram < 0.8 * available_memory + + print(f"The tarfile size is {tarfile_size} bytes on disk for an estimated {estimated_data_size_in_ram} bytes in memory and the system has {available_memory} bytes of RAM available.") + if fits_in_ram: + print("loading the dataset from the tarfile into RAM:") + self._data = list(make_base_data_generator_from_tarfile(data_tarfile)) + self._memory_map_mode = False + else: + print("loading the dataset into a memory-mapped file:") + self._memory_map_dir = tempfile.TemporaryDirectory() + + RaggedMmap.from_generator(out_dir=self._memory_map_dir.name, + sample_generator=make_flattened_tensor_generator( + make_base_data_generator_from_tarfile(data_tarfile)), + batch_size=10000, verbose=False) + self._data = RaggedMmap(self._memory_map_dir.name) + self._memory_map_mode = True + + # this is used in the batch sampler to make same-shape batches + self.indices_by_fold = [[] for _ in range(num_folds)] + + # totals by count, then by label -- ARTIFACT, VARIANT, UNLABELED, then by variant type + # variant type is done as a 1D np array parallel to the one-hot encoding of variant type + # we use a sentinel count value of 0 to denote aggregation over all counts + # eg totals[4][Label.ARTIFACT] = [2,4,6,8,10] means there are 2 artifact SNVs with alt count 4 + self.totals = defaultdict(lambda: {label: np.zeros(len(utils.Variation)) for label in Label}) + + # totals by count, then by source (integer) then by variant type + # basically same as above but with source instead of label. Since we don't know a priori how + # many sources there are, we use a default dict + # outer default dict is count, inner is source + self.source_totals = defaultdict(lambda: defaultdict(lambda: np.zeros(len(utils.Variation)))) + + self.counts_by_source = defaultdict(lambda: MutableInt()) # amount of data for each source (which is an integer key) + + for n, datum in enumerate(self): + self.counts_by_source[datum.source].increment() + + fold = n % num_folds + self.indices_by_fold[fold].append(n) + + variant_type_idx = datum.get_variant_type() + self.totals[ALL_COUNTS_SENTINEL][datum.label][variant_type_idx] += 1 + self.totals[datum.alt_count][datum.label][variant_type_idx] += 1 + self.source_totals[ALL_COUNTS_SENTINEL][datum.source][variant_type_idx] += 1 + self.source_totals[datum.alt_count][datum.source][variant_type_idx] += 1 + + + # compute weights to balance loss even for unbalanced data + # in the weights array, count == 0 (which never occurs as a real alt count) is the sentinel value for + # aggregation over all alt counts. The array is indexed by count, then label, then variation type + max_count = max(self.totals.keys()) + self.weights = np.zeros((max_count + 1, len(Label), len(utils.Variation))) + + # similar but indexed by count, then source, then variant type + max_source = max(self.source_totals[ALL_COUNTS_SENTINEL].keys()) + self.source_weights = np.zeros((max_count + 1, max_source + 1, len(utils.Variation))) + + sources = self.source_totals[ALL_COUNTS_SENTINEL].keys() + for count in self.totals.keys(): + # eg: if there are 1000 artifact and 10 non-artifact SNVs, the ratio is 100, and artifacts get a weight of 1/sqrt(100) = 1/10 + # while non-artifacts get a weight of 10 -- hence the effective count of each is 1000/10 = 10*10 = 100 + art_to_nonart_ratios = ratio_with_pseudocount(self.totals[count][Label.ARTIFACT], self.totals[count][Label.VARIANT]) + self.weights[count][Label.VARIANT] = np.sqrt(art_to_nonart_ratios) + self.weights[count][Label.ARTIFACT] = 1 / np.sqrt(art_to_nonart_ratios) + + effective_labeled_counts = self.totals[count][Label.ARTIFACT] * self.weights[count][Label.ARTIFACT] + \ + self.totals[count][Label.VARIANT] * self.weights[count][Label.VARIANT] + + # unlabeled data are weighted down to have at most the same total weight as labeled data + # example, 1000 unlabeled SNVs and 100 labeled SNVs -- unlabeled weight is 100/1000 = 1/10 + # example, 10 unlabeled and 100 labeled -- unlabeled weight is 1 + self.weights[count][Label.UNLABELED] = np.clip(ratio_with_pseudocount(effective_labeled_counts, self.totals[count][Label.UNLABELED]), 0,1) + + # by variant type, for this count + totals_over_sources = np.sum([self.source_totals[count][source] for source in sources]) + for source in sources: + self.source_weights[count][source] = np.sqrt(ratio_with_pseudocount(totals_over_sources, self.source_weights[count][source])) + + # normalize source prediction weights to have same total effective count. Note that this is modulated + # downstream by set_alpha on the gradient reversal layer applied before source prediction + effective_source_counts = np.sum([self.source_totals[count][source] * self.source_weights[count][source] for source in sources]) + source_weight_normalization = effective_labeled_counts / effective_source_counts + for source in sources: + self.source_weights[count][source] = self.source_weights[count][source] * source_weight_normalization + + self.weights = torch.from_numpy(self.weights) + self.source_weights = torch.from_numpy(self.source_weights) + self.num_read_features = self[0].get_reads_2d().shape[1] + self.num_info_features = len(self[0].get_info_tensor_1d()) + self.ref_sequence_length = len(self[0].get_ref_sequence_1d()) + + def __len__(self): + return len(self._data) // TENSORS_PER_BASE_DATUM if self._memory_map_mode else len(self._data) + + def __getitem__(self, index): + if self._memory_map_mode: + bottom_index = index * TENSORS_PER_BASE_DATUM + other_stuff = OneDimensionalData.from_np_array(self._data[bottom_index + 1]) + + return BaseDatum(reads_2d=self._data[bottom_index], ref_sequence_1d=None, alt_count=None, info_array_1d=None, + variant_type=None, label=None, source=None, variant=None, counts_and_seq_lks=None, + one_dimensional_data_override=other_stuff) + else: + return self._data[index] + + # it is often convenient to arbitrarily use the last fold for validation + def last_fold_only(self): + return [self.num_folds - 1] # use the last fold for validation + + def all_but_the_last_fold(self): + return list(range(self.num_folds - 1)) + + def all_but_one_fold(self, fold_to_exclude: int): + return list(range(fold_to_exclude)) + list(range(fold_to_exclude + 1, self.num_folds)) + + def all_folds(self): + return list(range(self.num_folds)) + + def make_data_loader(self, folds_to_use: List[int], batch_size: int, pin_memory=False, num_workers: int = 0): + sampler = SemiSupervisedBatchSampler(self, batch_size, folds_to_use) + return DataLoader(dataset=self, batch_sampler=sampler, collate_fn=BaseBatch, pin_memory=pin_memory, num_workers=num_workers) + + +# from a generator that yields BaseDatum(s), create a generator that yields the two numpy arrays needed to reconstruct the datum +def make_flattened_tensor_generator(base_data_generator): + for base_datum in base_data_generator: + yield base_datum.get_reads_2d() + yield base_datum.get_1d_data().to_np_array() + + +def make_base_data_generator_from_tarfile(data_tarfile): + # extract the tarfile to a temporary directory that will be cleaned up when the program ends + temp_dir = tempfile.TemporaryDirectory() + tar = tarfile.open(data_tarfile) + tar.extractall(temp_dir.name) + tar.close() + data_files = [os.path.abspath(os.path.join(temp_dir.name, p)) for p in os.listdir(temp_dir.name)] + + for file in data_files: + for datum in load_list_of_base_data(file): + yield datum + + +# ex: chunk([a,b,c,d,e], 3) = [[a,b,c], [d,e]] +def chunk(lis, chunk_size): + return [lis[i:i + chunk_size] for i in range(0, len(lis), chunk_size)] + + +# Labeled and unlabeled data are mixed. +# the artifact model handles weighting the losses to compensate for class imbalance between supervised and unsupervised +# thus the sampler is not responsible for balancing the data +class SemiSupervisedBatchSampler(Sampler): + def __init__(self, dataset: BaseDataset, batch_size, folds_to_use: List[int]): + # combine the index maps of all relevant folds + self.indices_to_use = [] + + for fold in folds_to_use: + self.indices_to_use.extend(dataset.indices_by_fold[fold]) + + self.batch_size = batch_size + self.num_batches = math.ceil(len(self.indices_to_use) // self.batch_size) + + def __iter__(self): + batches = [] # list of lists of indices -- each sublist is a batch + random.shuffle(self.indices_to_use) + batches.extend(chunk(self.indices_to_use, self.batch_size)) + random.shuffle(batches) + + return iter(batches) + + def __len__(self): + return self.num_batches + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/data/base_datum.py b/src/main/python/org/broadinstitute/hellbender/permutect/data/base_datum.py new file mode 100644 index 00000000000..7c31bd08ed7 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/data/base_datum.py @@ -0,0 +1,734 @@ +import copy +import math + +import numpy as np +import torch +from torch import Tensor, IntTensor, FloatTensor + +from typing import List + +from permutect.utils import Variation, Label, trim_alleles_on_right + +DEFAULT_NUMPY_FLOAT = np.float16 +DEFAULT_GPU_FLOAT = torch.float32 +DEFAULT_CPU_FLOAT = torch.float32 + +# base strings longer than this when encoding data +MAX_NUM_BASES_FOR_ENCODING = 13 + +MAX_FLOAT_16 = torch.finfo(torch.float16).max +MIN_FLOAT_16 = torch.finfo(torch.float16).min + + +def make_1d_sequence_tensor(sequence_string: str) -> np.ndarray: + """ + convert string of form ACCGTA into tensor [ 0, 1, 1, 2, 3, 0] + """ + result = np.zeros(len(sequence_string), dtype=np.uint8) + for n, char in enumerate(sequence_string): + integer = 0 if char == 'A' else (1 if char == 'C' else (2 if char == 'G' else 3)) + result[n] = integer + return result + + +def make_sequence_tensor(sequence_string: str) -> np.ndarray: + """ + convert string of form ACCGTA into 4-channel one-hot tensor + [ [1, 0, 0, 0, 0, 1], # A channel + [0, 1, 1, 0, 0, 0], # C channel + [0, 0, 0, 1, 0, 0], # G channel + [0, 0, 0, 0, 1, 0] ] # T channel + """ + result = np.zeros([4, len(sequence_string)]) + for n, char in enumerate(sequence_string): + channel = 0 if char == 'A' else (1 if char == 'C' else (2 if char == 'G' else 3)) + result[channel, n] = 1 + return result + + +def truncate_bases_if_necessary(bases: str): + return bases if len(bases) <= MAX_NUM_BASES_FOR_ENCODING else bases[:MAX_NUM_BASES_FOR_ENCODING] + + +# here we just butcher variants longer than 13 bases and chop!!! +def bases_as_base5_int(bases: str) -> int: + power_of_5 = 1 + bases_to_use = truncate_bases_if_necessary(bases) + result = 0 + for nuc in bases_to_use: + coeff = 1 if nuc == 'A' else (2 if nuc == 'C' else (3 if nuc == 'G' else 4)) + result += power_of_5 * coeff + power_of_5 *= 5 + return result + + +def bases5_as_base_string(base5: int) -> str: + result = "" + remaining = base5 + while remaining > 0: + digit = remaining % 5 + nuc = 'A' if digit == 1 else ('C' if digit == 2 else ('G' if digit == 3 else 'T')) + result += nuc + remaining = (remaining - digit) // 5 + return result + + +def convert_to_three_ints(n: int, base: int): + r3 = n % base + m3 = (n - r3) // base + r2 = m3 % base + m2 = (m3 - r2) // base + r1 = m2 % base + return r1, r2, r3 + + +def from_three_ints(r1, r2, r3, base): + return r3 + base*r2 + base*base*r1 + + +class Variant: + LENGTH = 10 # in order to compress to float16 we need three numbers for the large position integer and the alt, ref encodings + FLOAT_16_LIMIT = 2048 # float16 can *represent* bigger integers, but this is the limit of being reconstructed correctly + # if we get this wrong, the position encoding is wrong and the posterior data don't "line up" with the VCF data, + # causing very little filtering to actually occur + + def __init__(self, contig: int, position: int, ref: str, alt: str): + self.contig = contig + self.position = position + # note: it is very important to trim here, as early as possible, because truncating to 13 or fewer bases + # does not commute with trimming!!! If we are not consistent about trimming first, dataset variants and + # VCF variants might get inconsistent encodings!!! + self.ref, self.alt = trim_alleles_on_right(ref, alt) + + # note: if base strings are treated as numbers in base 5, uint32 (equivalent to two uint16's) can hold up to 13 bases + def to_np_array(self): + base = self.__class__.FLOAT_16_LIMIT + el1, el2, el3 = convert_to_three_ints(self.position, base) + el4, el5, el6 = convert_to_three_ints(bases_as_base5_int(self.ref), base) + el7, el8, el9 = convert_to_three_ints(bases_as_base5_int(self.alt), base) + return np.array([self.contig, el1, el2, el3, el4, el5, el6, el7, el8, el9], dtype=np.uint16) + + # do we need to specify that it's a uint32 array? + @classmethod + def from_np_array(cls, np_array: np.ndarray): + assert len(np_array) == cls.LENGTH + base = cls.FLOAT_16_LIMIT + position = from_three_ints(round(np_array[1]), round(np_array[2]), round(np_array[3]), base) + ref = bases5_as_base_string(from_three_ints(round(np_array[4]), round(np_array[5]), round(np_array[6]), base)) + alt = bases5_as_base_string(from_three_ints(round(np_array[7]), round(np_array[8]), round(np_array[9]), base)) + return cls(round(np_array[0]), position, ref, alt) + + def get_ref_as_int(self): + return bases_as_base5_int(self.ref) + + def get_alt_as_int(self): + return bases_as_base5_int(self.alt) + + +# count how many times a unit string is repeated at the beginning of a larger string +# eg 'ATATGGG', 'AT' -> 1; 'AGGGGG', 'G' -> 0; 'TTATTATTAGTTA', 'TTA' -> 3 +def count_leading_repeats(sequence: str, unit: str): + result = 0 + idx = 0 + unit_length = len(unit) + while (idx + unit_length - 1 < len(sequence)) and sequence[idx:idx + unit_length] == unit: + result += 1 + idx += unit_length + return result + + +# same, but at the end of a sequence +# eg 'ATATGGG', 'G' -> 3; 'AGGGGG', 'G' -> 5; 'TTATTATTAGTTA', 'TTA' -> 1 +def count_trailing_repeats(sequence: str, unit: str): + result = 0 + unit_length = len(unit) + idx = len(sequence) - unit_length # index at beginning of comparison eg 'GGATC', 'TC' starts at index 5 - 2 = 3, the 'T' + while idx >= 0 and sequence[idx:idx + unit_length] == unit: + result += 1 + idx -= unit_length + return result + + +def find_factors(n: int): + result = [] + for m in range(1, int(math.sqrt(n)) + 1): + if n % m == 0: + result.append(m) + if (n // m) > m: + result.append(n // m) + result.sort() + return result + + +# eg ACGACGACG, ACG -> True; TTATTA, TA -> False +def is_repeat(bases: str, unit: str): + unit_length = len(unit) + if len(bases) % unit_length == 0: + num_repeats = len(bases) // len(unit) + for repeat_idx in range(num_repeats): + start = repeat_idx * unit_length + if bases[start: start + unit_length] != unit: + return False + return True + else: + return False + + +# decompose an indel into its most basic repeated unit +# examples: "ATATAT" -> ("AT", 3); "AAAAA" -> ("A", 5); "TTGTTG" -> ("TTG", 2); "ATGTG" -> "ATGTG", 1 +def decompose_str_unit(indel_bases: str): + for unit_length in find_factors(len(indel_bases)): # note: these are sorted ascending + unit = indel_bases[:unit_length] + if is_repeat(indel_bases, unit): + return unit, (len(indel_bases) // unit_length) + return indel_bases, 1 + + +def get_str_info_array(ref_sequence_string: str, variant: Variant): + assert len(ref_sequence_string) % 2 == 1, "must be odd length to have well-defined middle" + middle_idx = (len(ref_sequence_string) - 1) // 2 + + ref, alt = variant.ref, variant.alt + + insertion_length = max(len(alt) - len(ref), 0) + deletion_length = max(len(ref) - len(alt), 0) + + if len(ref) == len(alt): + unit, num_units = alt, 1 + repeats_after = count_leading_repeats(ref_sequence_string[middle_idx + len(ref):], unit) + repeats_before = count_trailing_repeats(ref_sequence_string[:middle_idx], unit) + elif insertion_length > 0: + unit, num_units = decompose_str_unit(alt[1:]) # the inserted sequence is everything after the anchor base that matches ref + repeats_after = count_leading_repeats(ref_sequence_string[middle_idx + len(ref):], unit) + repeats_before = count_trailing_repeats(ref_sequence_string[:middle_idx+1], unit) # +1 accounts for the anchor base + else: + unit, num_units = decompose_str_unit(ref[1:]) # the deleted sequence is everything after the anchor base + # it's pretty arbitrary whether we include the deleted bases themselves as 'after' or not + repeats_after = count_leading_repeats(ref_sequence_string[middle_idx + len(alt):], unit) + repeats_before = count_trailing_repeats(ref_sequence_string[:middle_idx+1], unit) # likewise, account for the anchor base + # note that if indels are left-aligned (as they should be from the GATK) repeats_before really ought to be zero!! + return np.array([insertion_length, deletion_length, len(unit), num_units, repeats_before, repeats_after]) + + +class CountsAndSeqLks: + LENGTH = 6 + + def __init__(self, depth: int, alt_count: int, normal_depth: int, normal_alt_count: int, + seq_error_log_lk: float, normal_seq_error_log_lk: float): + self.depth = depth + self.alt_count = alt_count + self.normal_depth = normal_depth + self.normal_alt_count = normal_alt_count + self.seq_error_log_lk = seq_error_log_lk + self.normal_seq_error_log_lk = normal_seq_error_log_lk + + def to_np_array(self): + return np.array([self.depth, self.alt_count, self.normal_depth, self.normal_alt_count, self.seq_error_log_lk, self.normal_seq_error_log_lk]) + + @classmethod + def from_np_array(cls, np_array: np.ndarray): + assert len(np_array) == cls.LENGTH + return cls(round(np_array[0]), round(np_array[1]), round(np_array[2]), round(np_array[3]), float(np_array[4]), float(np_array[5])) + + +class TensorSizes: + LENGTH = 4 + + def __init__(self, ref_count: int, alt_count: int, ref_sequence_length: int, info_tensor_length: int): + self.ref_count = ref_count + self.alt_count = alt_count + self.ref_sequence_length = ref_sequence_length + self.info_tensor_length = info_tensor_length + + def to_np_array(self): + return np.array([self.ref_count, self.alt_count, self.ref_sequence_length, self.info_tensor_length]) + + @classmethod + def from_np_array(cls, np_array: np.ndarray): + assert len(np_array) == cls.LENGTH + return cls(round(np_array[0]), round(np_array[1]), round(np_array[2]), round(np_array[3])) + + +# This applies to both BaseDatum AND ArtifactDatum. ArtifactDatum is simply the special case where the ref seq and info +# tensor sizes are zero. +class OneDimensionalData: + REF_COUNT_IDX = 0 + ALT_COUNT_IDX = 1 + REF_SEQ_LENGTH_IDX = 2 + INFO_LENGTH_IDX = 3 + REF_SEQ_START_IDX = 4 + + # starting at index 4 is ref seq and then info, if these are not empty as in the case of artifact data + NUM_ELEMENTS_AFTER_INFO = 3 + Variant.LENGTH + CountsAndSeqLks.LENGTH # 1 for variant type, 1 for the label, 1 for the source (integer) + + VAR_TYPE_IDX = -NUM_ELEMENTS_AFTER_INFO + LABEL_IDX = -NUM_ELEMENTS_AFTER_INFO + 1 + SOURCE_IDX = -NUM_ELEMENTS_AFTER_INFO + 2 + + VARIANT_START_IDX = -NUM_ELEMENTS_AFTER_INFO + 3 + VARIANT_END_IDX = VARIANT_START_IDX + Variant.LENGTH + COUNTS_AND_SEQ_LKS_START_IDX = VARIANT_END_IDX + + # 1st four elements are tensor sizes: ref count, alt count, ref seq length, info length + # next is ref sequence as 1D array + # next is info 1D array + # variant type, label, source (each a single int) + # Variant (Variant.LENGTH elements) + # CountsAndSeqLks (CountsAndSeqLks.LENGTH elements) + def __init__(self, tensor_sizes: TensorSizes, ref_sequence_1d: np.ndarray, info_array_1d: np.ndarray, + variant_type: Variation, label: Label, source: int, variant: Variant, counts_and_seq_lks: CountsAndSeqLks, + array_override: np.ndarray = None): + if array_override is None: + # note: Label is an IntEnum so we can treat label as an integer + self.array = np.hstack((tensor_sizes.to_np_array(), ref_sequence_1d, info_array_1d, np.array([variant_type, label, source]), + variant.to_np_array(), counts_and_seq_lks.to_np_array())) + else: + self.array = array_override + + # used for sending the data from BaseDatum to ArtifactDatum + def copy_without_ref_seq_and_info(self): + # keep ref count (0) and alt count (1) the same; set ref seq length (2) and info tensor length (3) to zero + new_tensor_size_array = np.array([self.array[0], self.array[1], 0, 0]) + + # relies on the layout of tensor sizes, then ref seq and info, then the rest + new_array = np.hstack((new_tensor_size_array, self.array[-self.__class__.NUM_ELEMENTS_AFTER_INFO:])) + return OneDimensionalData(tensor_sizes=None, ref_sequence_1d=None, info_array_1d=None, variant_type=None, label=None, + source=None, variant=None, counts_and_seq_lks=None, array_override=new_array) + + def get_nbytes(self): + return self.array.nbytes + + def set_dtype(self, dtype): + self.array = self.array.astype(dtype) + + def get_ref_count(self) -> int: + return round(self.array[self.__class__.REF_COUNT_IDX]) + + def get_alt_count(self) -> int: + return round(self.array[self.__class__.ALT_COUNT_IDX]) + + def get_ref_seq_1d(self): + ref_seq_length = round(self.array[self.__class__.REF_SEQ_LENGTH_IDX]) + start = self.__class__.REF_SEQ_START_IDX + assert ref_seq_length > 0, "trying to get ref seq array when none exists -- is this used in an ArtifactDatum?" + return self.array[start:start + ref_seq_length] + + def get_info_1d(self): + ref_seq_length = round(self.array[self.__class__.REF_SEQ_LENGTH_IDX]) + info_length = round(self.array[self.__class__.INFO_LENGTH_IDX]) + start = self.__class__.REF_SEQ_START_IDX + ref_seq_length + assert info_length > 0, "trying to get info array when none exists -- is this used in an ArtifactDatum?" + return self.array[start:start + info_length] + + # note: this potentially resizes the array and requires the leading info tensor size element to be modified + # we do this in preprocessing when adding extra info to the info from GATK. + # this method should not otherwise be used!!! + def set_info_1d(self, new_info: np.ndarray): + ref_seq_length = round(self.array[self.__class__.REF_SEQ_LENGTH_IDX]) + old_info_length = round(self.array[self.__class__.INFO_LENGTH_IDX]) + + before_info = self.array[:self.__class__.REF_SEQ_START_IDX + ref_seq_length] + after_info = self.array[-self.__class__.NUM_ELEMENTS_AFTER_INFO:] + + self.array[self.__class__.INFO_LENGTH_IDX] = len(new_info) # update the info tensor size + self.array = np.hstack((before_info, new_info, after_info)) + + def get_variant_type(self) -> int: + return round(self.array[self.__class__.VAR_TYPE_IDX]) + + def set_variant_type(self, variant_type: Variation): + self.array[self.__class__.VAR_TYPE_IDX] = variant_type + + def get_label(self): + return self.array[self.__class__.LABEL_IDX] + + def set_label(self, label: Label): + self.array[self.__class__.LABEL_IDX] = label + + def get_source(self) -> int: + return round(self.array[self.__class__.SOURCE_IDX]) + + def set_source(self, source: int): + self.array[self.__class__.SOURCE_IDX] = source + + def get_variant(self): + return Variant.from_np_array(self.array[-self.__class__.NUM_ELEMENTS_AFTER_INFO + 3:-CountsAndSeqLks.LENGTH]) + + def get_variant_array(self): + return self.array[self.__class__.VARIANT_START_IDX:self.__class__.VARIANT_END_IDX] + + def get_counts_and_seq_lks(self): + return CountsAndSeqLks.from_np_array(self.array[self.__class__.COUNTS_AND_SEQ_LKS_START_IDX:]) + + def get_counts_and_seq_lks_array(self): + return self.array[self.__class__.COUNTS_AND_SEQ_LKS_START_IDX:] + + def to_np_array(self): + return self.array + + @classmethod + def from_np_array(cls, np_array: np.ndarray): + return cls(tensor_sizes=None, ref_sequence_1d=None, info_array_1d=None, variant_type=None, label=None, source=None, + variant=None, counts_and_seq_lks=None, array_override=np_array) + + +class BaseDatum: + """ + :param ref_sequence_1d 1D uint8 tensor of bases centered at the alignment start of the variant in form eg ACTG -> [0,1,3,2] + :param reads_2d 2D tensor, each row corresponding to one read; first all the ref reads, then all the alt reads + :param info_array_1d 1D tensor of information about the variant as a whole + :param label an object of the Label enum artifact, non-artifact, unlabeled + """ + def __init__(self, reads_2d: np.ndarray, ref_sequence_1d: np.ndarray, alt_count: int, info_array_1d: np.ndarray, + variant_type: Variation, label: Label, source: int, variant: Variant, counts_and_seq_lks: CountsAndSeqLks, + one_dimensional_data_override: OneDimensionalData = None): + # Note: if changing any of the data fields below, make sure to modify the size_in_bytes() method below accordingly! + + self.reads_2d = reads_2d + + if one_dimensional_data_override is None: + self.alt_count = alt_count + self.label = label + self.source = source + tensor_sizes = TensorSizes(ref_count=len(reads_2d) - alt_count, alt_count=alt_count, + ref_sequence_length=len(ref_sequence_1d), info_tensor_length=len(info_array_1d)) + self.other_stuff = OneDimensionalData(tensor_sizes, ref_sequence_1d, info_array_1d, variant_type, label, source, variant, counts_and_seq_lks) + else: + self.other_stuff = one_dimensional_data_override + self.alt_count = one_dimensional_data_override.get_alt_count() + self.label = one_dimensional_data_override.get_label() + self.source = one_dimensional_data_override.get_source() + self.set_dtype(np.float16) + + def set_dtype(self, dtype): + self.other_stuff.set_dtype(dtype) + self.reads_2d = self.reads_2d.astype(dtype) + + # gatk_info tensor comes from GATK and does not include one-hot encoding of variant type + @classmethod + def from_gatk(cls, ref_sequence_string: str, variant_type: Variation, ref_tensor: np.ndarray, alt_tensor: np.ndarray, + gatk_info_tensor: np.ndarray, label: Label, source: int, variant: Variant, counts_and_seq_lks: CountsAndSeqLks = None): + read_tensor = np.vstack([ref_tensor, alt_tensor]) if ref_tensor is not None else alt_tensor + alt_count = len(alt_tensor) + str_info = get_str_info_array(ref_sequence_string, variant) + info_tensor = np.hstack([gatk_info_tensor, str_info]) + result = cls(read_tensor, make_1d_sequence_tensor(ref_sequence_string), alt_count, info_tensor, variant_type, label, source, variant, counts_and_seq_lks) + result.set_dtype(np.float16) + return result + + def size_in_bytes(self): + return self.reads_2d.nbytes + self.other_stuff.get_nbytes() + + def get_reads_2d(self): + return self.reads_2d + + def get_1d_data(self) -> OneDimensionalData: + return self.other_stuff + + def get_variant_type(self) -> int: + return self.other_stuff.get_variant_type() + + def set_label(self, label: Label): + self.label = label + self.other_stuff.set_label(label) + + def set_source(self, source: int): + self.source = source + self.other_stuff.set_source(source) + + def get_ref_reads_2d(self) -> np.ndarray: + return self.reads_2d[:-self.alt_count] + + def get_alt_reads_2d(self) -> np.ndarray: + return self.reads_2d[-self.alt_count:] + + def get_info_tensor_1d(self) -> np.ndarray: + return self.other_stuff.get_info_1d() + + def set_info_tensor_1d(self, new_info: np.ndarray) -> np.ndarray: + return self.other_stuff.set_info_1d(new_info) + + def get_ref_sequence_1d(self) -> np.ndarray: + return self.other_stuff.get_ref_seq_1d() + + # returns two length-L 1D arrays of ref stacked on top of alt, with '4' in alt(ref) for deletions(insertions) + def get_ref_and_alt_sequences(self): + original_ref_array = self.get_ref_sequence_1d() # gives an array eg ATTTCGG -> [0,3,3,3,1,2,2] + assert len(original_ref_array) % 2 == 1, "ref sequence length should be odd" + middle_idx = (len(original_ref_array) - 1) // 2 + max_allele_length = middle_idx # just kind of a coincidence + variant = self.other_stuff.get_variant() + ref, alt = variant.ref[:max_allele_length], variant.alt[:max_allele_length] # these are strings, not integers + + if len(ref) >= len(alt): # substitution or deletion + ref_array = original_ref_array + alt_array = np.copy(ref_array) + deletion_length = len(ref) - len(alt) + # add the deletion value '4' to make the alt allele array as long as the ref allele + alt_allele_array = make_1d_sequence_tensor(alt) if deletion_length == 0 else np.hstack((make_1d_sequence_tensor(alt), np.full(shape=deletion_length, fill_value=4))) + alt_array[middle_idx: middle_idx + len(alt_allele_array)] = alt_allele_array + else: # insertion + insertion_length = len(alt) - len(ref) + before = original_ref_array[:middle_idx] + after = original_ref_array[middle_idx + len(ref):-insertion_length] + + alt_allele_array = make_1d_sequence_tensor(alt) + ref_allele_array = np.hstack((make_1d_sequence_tensor(ref), np.full(shape=insertion_length, fill_value=4))) + + ref_array = np.hstack((before, ref_allele_array, after)) + alt_array = np.hstack((before, alt_allele_array, after)) + + assert len(ref_array) == len(alt_array) + if len(ref) == len(alt): # SNV -- ref and alt ought to be different + assert alt_array[middle_idx] != ref_array[middle_idx] + else: # indel -- ref and alt are the same at the anchor base, then are different + assert alt_array[middle_idx + 1] != ref_array[middle_idx + 1] + return ref_array[:len(original_ref_array)], alt_array[:len(original_ref_array)] # this clipping may be redundant + + +def save_list_base_data(base_data: List[BaseDatum], file): + """ + note that torch.save works fine with numpy data + :param base_data: + :param file: + :return: + """ + # TODO: should I combine stack these into big arrays rather than leaving them as lists of arrays? + read_tensors = np.vstack([datum.get_reads_2d() for datum in base_data]) + other_stuff = np.vstack([datum.get_1d_data().to_np_array() for datum in base_data]) + torch.save([read_tensors, other_stuff], file) + + +def load_list_of_base_data(file) -> List[BaseDatum]: + """ + file is torch, output is converted back to numpy + :param file: + :return: + """ + # these are vstacked -- see save method above + read_tensors, other_stuffs = torch.load(file) + + result = [] + read_start_row = 0 + for other_stuff_numpy in other_stuffs: + other_stuff = OneDimensionalData.from_np_array(other_stuff_numpy) + read_count = other_stuff.get_ref_count() + other_stuff.get_alt_count() + read_end_row = read_start_row+read_count + + base_datum = BaseDatum(reads_2d=read_tensors[read_start_row:read_end_row], ref_sequence_1d=None, alt_count=None, info_array_1d=None, + variant_type=None, label=None, source=None, + variant=None, counts_and_seq_lks=None, one_dimensional_data_override=other_stuff) + read_start_row = read_end_row + result.append(base_datum) + + return result + + +class BaseBatch: + """ + Read sets have different sizes so we can't form a batch by naively stacking tensors. We need a custom way + to collate a list of Datum into a Batch + + collated batch contains: + 2D tensors of ALL ref (alt) reads, not separated by set. + number of reads in ref (alt) read sets, in same order as read tensors + info: 2D tensor of info fields, one row per variant + labels: 1D tensor of 0 if non-artifact, 1 if artifact + lists of original mutect2_data and site info + + Example: if we have two input data, one with alt reads [[0,1,2], [3,4,5] and the other with + alt reads [[6,7,8], [9,10,11], [12,13,14] then the output alt reads tensor is + [[0,1,2], [3,4,5], [6,7,8], [9,10,11], [12,13,14]] and the output counts are [2,3] + inside the model, the counts will be used to separate the reads into sets + """ + + def __init__(self, data: List[BaseDatum]): + # TODO: can we get rid of this potential bottleneck (might interact really badly with multiple workers)? + self._original_list = data + + # num_classes = 5 for A, C, G, T, and deletion / insertion + ref_alt = [torch.flatten(torch.permute(torch.nn.functional.one_hot(torch.from_numpy(np.vstack(item.get_ref_and_alt_sequences())).long(), num_classes=5), (0,2,1)), 0, 1) for item in data] # list of 2D (2x5)xL + # this is indexed by batch, length, channel (aka one-hot base encoding) + ref_alt_bcl = torch.stack(ref_alt) + + self.ref_sequences_2d = ref_alt_bcl + + list_of_ref_tensors = [item.get_ref_reads_2d() for item in data] + list_of_alt_tensors = [item.get_alt_reads_2d() for item in data] + self.reads_2d = torch.from_numpy(np.vstack(list_of_ref_tensors + list_of_alt_tensors)) + self.info_2d = torch.from_numpy(np.vstack([base_datum.get_info_tensor_1d() for base_datum in data])) + + ref_counts = IntTensor([len(datum.reads_2d) - datum.alt_count for datum in data]) + alt_counts = IntTensor([datum.alt_count for datum in data]) + labels = IntTensor([1 if item.label == Label.ARTIFACT else 0 for item in data]) + is_labeled_mask = IntTensor([0 if item.label == Label.UNLABELED else 1 for item in data]) + sources = IntTensor([item.source for item in data]) + variant_types = IntTensor([datum.get_variant_type() for datum in data]) + self.int_tensor = torch.vstack((ref_counts, alt_counts, labels, is_labeled_mask, sources, variant_types)) + + self._size = len(data) + + # pin memory for all tensors that are sent to the GPU + def pin_memory(self): + self.ref_sequences_2d = self.ref_sequences_2d.pin_memory() + self.reads_2d = self.reads_2d.pin_memory() + self.info_2d = self.info_2d.pin_memory() + self.int_tensor = self.int_tensor.pin_memory() + + return self + + def copy_to(self, device, non_blocking): + # For all non-tensor attributes, shallow copy is sufficient + new_batch = copy.copy(self) + new_batch.ref_sequences_2d = self.ref_sequences_2d.to(device, non_blocking=non_blocking) + new_batch.reads_2d = self.reads_2d.to(device, non_blocking=non_blocking) + new_batch.info_2d = self.info_2d.to(device, non_blocking=non_blocking) + new_batch.int_tensor = self.int_tensor.to(device, non_blocking=non_blocking) + return new_batch + + def original_list(self): + return self._original_list + + def get_reads_2d(self) -> Tensor: + return self.reads_2d + + def get_ref_counts(self) -> IntTensor: + return self.int_tensor[0, :] + + def get_alt_counts(self) -> IntTensor: + return self.int_tensor[1, :] + + # the original IntEnum format + def get_labels(self): + return self.int_tensor[2, :] + + def get_training_labels(self): + int_enum_labels = self.get_labels() + return 1.0 * (int_enum_labels == Label.ARTIFACT) + 0.5 * (int_enum_labels == Label.UNLABELED) + + def get_is_labeled_mask(self) -> IntTensor: + return self.int_tensor[3, :] + + def get_sources(self) -> IntTensor: + return self.int_tensor[4, :] + + def get_variant_types(self) -> IntTensor: + return self.int_tensor[5, :] + + def get_info_2d(self) -> Tensor: + return self.info_2d + + def get_ref_sequences_2d(self) -> Tensor: + return self.ref_sequences_2d + + def size(self) -> int: + return self._size + + +class ArtifactDatum: + """ + """ + def __init__(self, base_datum: BaseDatum, representation: Tensor): + # Note: if changing any of the data fields below, make sure to modify the size_in_bytes() method below accordingly! + assert representation.dim() == 1 + self.representation = torch.clamp(representation, MIN_FLOAT_16, MAX_FLOAT_16) + self.one_dimensional_data = base_datum.get_1d_data().copy_without_ref_seq_and_info() + self.set_dtype(np.float16) + + def set_dtype(self, dtype): + self.representation = self.representation.to(torch.float16) + self.one_dimensional_data.set_dtype(dtype) + + def get_ref_count(self) -> int: + return self.one_dimensional_data.get_ref_count() + + def get_alt_count(self) -> int: + return self.one_dimensional_data.get_alt_count() + + def get_depth(self) -> int: + return self.one_dimensional_data.get_counts_and_seq_lks().depth + + def get_variant_type(self) -> int: + return self.one_dimensional_data.get_variant_type() + + def get_label(self): + return self.one_dimensional_data.get_label() + + def get_source(self) -> int: + return self.one_dimensional_data.get_source() + + def size_in_bytes(self): + return self.representation.nbytes + self.one_dimensional_data.get_nbytes() + + def get_1d_data(self) -> OneDimensionalData: + return self.one_dimensional_data + + def is_labeled(self): + return self.get_label() != Label.UNLABELED + + +class ArtifactBatch: + def __init__(self, data: List[ArtifactDatum]): + self.representations_2d = torch.vstack([item.representation for item in data]) + self.other_stuff_array = torch.from_numpy(np.vstack([d.get_1d_data().to_np_array() for d in data])) + self._size = len(data) + + def get_variants(self) -> List[Variant]: + relevant_cols = self.other_stuff_array[:, OneDimensionalData.VARIANT_START_IDX:OneDimensionalData.VARIANT_END_IDX].numpy() + return [Variant.from_np_array(var_array_1d) for var_array_1d in relevant_cols] + + def get_counts_and_seq_lks(self) -> List[CountsAndSeqLks]: + relevant_cols = self.other_stuff_array[:, OneDimensionalData.COUNTS_AND_SEQ_LKS_START_IDX:].numpy() + return [CountsAndSeqLks.from_np_array(var_array_1d) for var_array_1d in relevant_cols] + + # get the original IntEnum format (VARIANT = 0, ARTIFACT = 1, UNLABELED = 2) labels + def get_labels(self) -> IntTensor: + return self.other_stuff_array[:, OneDimensionalData.LABEL_IDX].int() + + # TODO: left off here + # TODO: put in some breakpoints to double-check that this works + # convert to the training format of 0.0 / 0.5 / 1.0 for variant / unlabeled / artifact + # the 0.5 for unlabeled data is reasonable but should never actually be used due to the is_labeled mask + def get_training_labels(self) -> FloatTensor: + int_enum_labels = self.get_labels() + return 1.0 * (int_enum_labels == Label.ARTIFACT) + 0.5 * (int_enum_labels == Label.UNLABELED) + + # TODO: put in some breakpoints to double-check that this works + def get_is_labeled_mask(self): + int_enum_labels = self.get_labels() + return (int_enum_labels != Label.UNLABELED).int() + + def get_sources(self) -> IntTensor: + return self.other_stuff_array[:, OneDimensionalData.SOURCE_IDX].int() + + def get_variant_types(self) -> IntTensor: + result = self.other_stuff_array[:, OneDimensionalData.VAR_TYPE_IDX].int() + return result + + def get_ref_counts(self) -> IntTensor: + return self.other_stuff_array[:, OneDimensionalData.REF_COUNT_IDX].int() + + def get_alt_counts(self) -> IntTensor: + return self.other_stuff_array[:, OneDimensionalData.ALT_COUNT_IDX].int() + + # pin memory for all tensors that are sent to the GPU + def pin_memory(self): + self.representations_2d = self.representations_2d.pin_memory() + self.other_stuff_array = self.other_stuff_array.pin_memory() + return self + + def copy_to(self, device, dtype, non_blocking): + # For all non-tensor attributes, shallow copy is sufficient + # note that variants_array and counts_and_seq_lks_array are not used in training and are never sent to GPU + new_batch = copy.copy(self) + new_batch.representations_2d = self.representations_2d.to(device=device, dtype=dtype, non_blocking=non_blocking) + new_batch.other_stuff_array = self.other_stuff_array.to(device, dtype=dtype, non_blocking=non_blocking) + + return new_batch + + def get_representations_2d(self) -> Tensor: + return self.representations_2d + + def size(self) -> int: + return self._size + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/data/plain_text_data.py b/src/main/python/org/broadinstitute/hellbender/permutect/data/plain_text_data.py new file mode 100644 index 00000000000..f313d51095f --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/data/plain_text_data.py @@ -0,0 +1,226 @@ +""" +Functions for reading from plain text dataset files of the format + +UNLABELED # label +1:13118,A->G # locus and mutation +GAGGAAAGTGAGGTTGCCTGC # reference context +0.12 0.09 2.00 0.57 1.00 1.00 1.00 1.00 1.00 # variant-level info vector +5 4 0 0 # ref count, alt count, matched normal ref count, matched normal alt count +27 30 0 1 11 29 333 321 12 0 0 # one ref read vector per line +50 22 1 0 30 19 342 70 272 0 0 +27 31 0 1 32 17 236 203 33 0 0 +27 20 0 1 32 17 141 72 69 1 0 +21 28 1 0 49 0 232 49 183 1 0 +23 29 1 1 40 9 335 294 41 0 0 # one alt read vector per line +24 29 0 1 38 11 354 315 39 0 0 +24 30 0 1 36 13 351 314 37 0 0 +23 30 1 1 42 7 341 298 43 0 0 +51 13 0 0 # original ref, alt, normal ref, normal alt counts before downsampling +-108.131 # sequencing error log likelihood +-0.000 # matched normal sequencing error log likelihood +VARIANT +1:13302,C->T +GTCCTGGACACGCTGTTGGCC +0.00 0.00 0.00 1.00 1.00 2.00 1.00 1.00 1.00 +2 1 0 0 +24 29 0 0 11 21 338 11 327 0 0 +50 25 0 1 49 -8 355 303 52 0 0 +23 33 1 0 13 21 312 87 225 0 0 +69 4 0 0 +-11.327 +-0.000 +""" +from typing import List + +import numpy as np +import psutil +from sklearn.preprocessing import QuantileTransformer + +from permutect import utils +from permutect.data.base_datum import BaseDatum, Variant, CountsAndSeqLks, DEFAULT_NUMPY_FLOAT + +from permutect.utils import Label, Variation + +MAX_VALUE = 10000 +EPSILON = 0.00001 +QUANTILE_DATA_COUNT = 10000 + + +def read_data(dataset_file, only_artifacts: bool = False, source: int=0): + """ + generator that yields data from a plain text dataset file. + """ + with open(dataset_file) as file: + n = 0 + while label_str := file.readline().strip(): + label = Label.get_label(label_str) + passes_label_filter = (label == Label.ARTIFACT or not only_artifacts) + n += 1 + + # contig:position,ref->alt + variant_line = file.readline().strip() + locus, mutation = variant_line.split(",") + contig, position = map(int, locus.split(":")) # contig is an integer *index* from a sequence dictionary + # TODO: replace with tqdm progress bar by counting file in initial pass. It can't be that expensive. + if n % 100000 == 0: + print(f"{contig}:{position}") + ref_allele, alt_allele = mutation.strip().split("->") + + ref_sequence_string = file.readline().strip() + gatk_info_tensor = line_to_tensor(file.readline()) + ref_tensor_size, alt_tensor_size, normal_ref_tensor_size, normal_alt_tensor_size = map(int, file.readline().strip().split()) + + # the first column is read group index, which we currently discard + # later we're going to want to use this + ref_tensor = read_2d_tensor(file, ref_tensor_size)[:,1:] if ref_tensor_size > 0 else None + alt_tensor = read_2d_tensor(file, alt_tensor_size)[:,1:] if alt_tensor_size > 0 else None + + # normal_ref_tensor = read_2d_tensor(file, normal_ref_tensor_size) # not currently used + # normal_alt_tensor = read_2d_tensor(file, normal_alt_tensor_size) # not currently used + # round down normal tensors as well + + depth, alt_count, normal_depth, normal_alt_count = read_integers(file.readline()) + seq_error_log_lk = read_float(file.readline()) + normal_seq_error_log_lk = read_float(file.readline()) + + if alt_tensor_size > 0 and passes_label_filter: + variant = Variant(contig, position, ref_allele, alt_allele) + counts_and_seq_lks = CountsAndSeqLks(depth, alt_count, normal_depth, normal_alt_count, seq_error_log_lk, normal_seq_error_log_lk) + yield BaseDatum.from_gatk(ref_sequence_string, Variation.get_type(ref_allele, alt_allele), ref_tensor, + alt_tensor, gatk_info_tensor, label, source, variant, counts_and_seq_lks) + + +# if sources is None, source is set to zero +# if List is length-1, that's the source for all files +# otherwise each file has its own source int +def generate_normalized_data(dataset_files, max_bytes_per_chunk: int, sources: List[int]=None): + """ + given text dataset files, generate normalized lists of read sets that fit in memory + + In addition to quantile-normalizing read tensors it also enlarges the info tensors + :param dataset_files: + :param max_bytes_per_chunk: + :return: + """ + for n, dataset_file in enumerate(dataset_files): + buffer, bytes_in_buffer = [], 0 + read_quantile_transform = QuantileTransformer(n_quantiles=100, output_distribution='normal') + info_quantile_transform = QuantileTransformer(n_quantiles=100, output_distribution='normal') + + num_buffers_filled = 0 + source = 0 if sources is None else (sources[0] if len(sources) == 1 else sources[n]) + for base_datum in read_data(dataset_file, source=source): + buffer.append(base_datum) + bytes_in_buffer += base_datum.size_in_bytes() + if bytes_in_buffer > max_bytes_per_chunk: + print(f"Memory usage percent: {psutil.virtual_memory().percent:.1f}") + print(f"{bytes_in_buffer} bytes in chunk") + + normalize_buffer(buffer, read_quantile_transform, info_quantile_transform) + yield buffer + num_buffers_filled += 1 + buffer, bytes_in_buffer = [], 0 + # There will be some data left over, in general. Since it's small, use the last buffer's + # quantile transforms for better statistical power if it's from the same text file + if buffer: + normalize_buffer(buffer, read_quantile_transform, info_quantile_transform, refit_transforms=(num_buffers_filled==0)) + yield buffer + + +# this normalizes the buffer and also prepends new features to the info tensor +def normalize_buffer(buffer, read_quantile_transform, info_quantile_transform, refit_transforms=True): + EPSILON = 0.00001 # tiny quantity for jitter + + # 2D array. Rows are ref/alt reads, columns are read features + all_ref = np.vstack([datum.get_ref_reads_2d() for datum in buffer]) + all_reads = np.vstack([datum.reads_2d for datum in buffer]) + + # 2D array. Rows are read sets, columns are info features + all_info = np.vstack([datum.get_info_tensor_1d() for datum in buffer]) + + all_ref_jittered = all_ref + EPSILON * np.random.randn(*all_ref.shape) + all_reads_jittered = all_reads + EPSILON * np.random.randn(*all_reads.shape) + all_info_jittered = all_info + EPSILON * np.random.randn(*all_info.shape) + + num_read_features = all_ref.shape[1] + binary_read_columns = binary_column_indices(all_ref) # make sure not to use jittered arrays here! + + # 1 if is binary, 0 if not binary + binary_read_column_mask = np.zeros(num_read_features) + binary_read_column_mask[binary_read_columns] = 1 + + binary_info_columns = binary_column_indices(all_info) # make sure not to use jittered arrays here! + + if refit_transforms: # fit quantiles column by column (aka feature by feature) + read_quantile_transform.fit(all_ref_jittered) + info_quantile_transform.fit(all_info_jittered) + + # it's more efficient to apply the quantile transform to all reads at once, then split it back into read sets + all_reads_transformed = transform_except_for_binary_columns(all_reads_jittered, read_quantile_transform, binary_read_columns) + all_info_transformed = transform_except_for_binary_columns(all_info_jittered, info_quantile_transform, binary_info_columns) + + read_counts = np.array([len(datum.reads_2d) for datum in buffer]) + read_index_ranges = np.cumsum(read_counts) + + for n, datum in enumerate(buffer): + datum.reads_2d = all_reads_transformed[0 if n == 0 else read_index_ranges[n-1]:read_index_ranges[n]] + + # medians are an appropriate outlier-tolerant summary, except for binary columns where the mean makes more sense + alt_medians = np.median(datum.get_alt_reads_2d(), axis=0) + alt_means = np.mean(datum.get_alt_reads_2d(), axis=0) + + extra_info = binary_read_column_mask * alt_means + (1 - binary_read_column_mask) * alt_medians + datum.set_info_tensor_1d(np.hstack([extra_info, all_info_transformed[n]])) + + +def line_to_tensor(line: str) -> np.ndarray: + tokens = line.strip().split() + floats = [float(token) for token in tokens] + return np.clip(np.array(floats, dtype=DEFAULT_NUMPY_FLOAT), -MAX_VALUE, MAX_VALUE) + + +def read_2d_tensor(file, num_lines: int) -> np.ndarray: + if num_lines == 0: + return None + lines = [file.readline() for _ in range(num_lines)] + return np.vstack([line_to_tensor(line) for line in lines]) + + +def read_integers(line: str): + return map(int, line.strip().split()) + + +def read_float(line: str): + return float(line.strip().split()[0]) + + +def is_binary(column_tensor_1d: np.ndarray): + assert len(column_tensor_1d.shape) == 1 + return all(el.item() == 0 or el.item() == 1 for el in column_tensor_1d) + + +def binary_column_indices(tensor_2d: np.ndarray): + assert len(tensor_2d.shape) == 2 + return [n for n in range(tensor_2d.shape[1]) if is_binary(tensor_2d[:, n])] + + +def non_binary_column_indices(tensor_2d: np.ndarray): + assert len(tensor_2d.shape) == 2 + return [n for n in range(tensor_2d.shape[1]) if not is_binary(tensor_2d[:, n])] + + +# copy the unnormalized values of binary features (columns) +# we modify the normalized values in-place +def restore_binary_columns(normalized, original, binary_columns): + result = normalized + if len(normalized.shape) == 2: + result[:, binary_columns] = original[:, binary_columns] + elif len(normalized.shape) == 1: + result[binary_columns] = original[binary_columns] + else: + raise Exception("This is only for 1D or 2D tensors") + return result + + +def transform_except_for_binary_columns(tensor_2d, quantile_transform: QuantileTransformer, binary_column_indices): + return restore_binary_columns(quantile_transform.transform(tensor_2d), tensor_2d, binary_column_indices) \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/data/posterior.py b/src/main/python/org/broadinstitute/hellbender/permutect/data/posterior.py new file mode 100644 index 00000000000..42017224820 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/data/posterior.py @@ -0,0 +1,162 @@ +import copy +import random +import math +from typing import List, Iterable + +import torch +from torch import IntTensor +from torch.utils.data import Dataset, DataLoader +from permutect.data.base_datum import Variant, CountsAndSeqLks, bases5_as_base_string + +from permutect import utils +from permutect.utils import Label, Variation + + +def variant_from_int_array(subarray) -> Variant: + contig = subarray[0].item() + position = subarray[1].item() + ref = bases5_as_base_string(subarray[2].item()) # ref and alt are the base-5 encoding as integers + alt = bases5_as_base_string(subarray[3].item()) + return Variant(contig, position, ref, alt) + + +class PosteriorDatum: + CONTIG = 0 + POSITION = 1 + REF = 2 + ALT = 3 + VAR_TYPE = 4 + DEPTH = 5 + ALT_COUNT = 6 + NORMAL_DEPTH = 7 + NORMAL_ALT_COUNT = 8 + LABEL = 9 + + SEQ_ERROR_LOG_LK = 0 + TLOD_FROM_M2 = 1 + NORMAL_SEQ_ERROR_LOG_LK = 2 + ALLELE_FREQUENCY = 3 + ARTIFACT_LOGIT = 4 + MAF = 5 + NORMAL_MAF = 6 + + def __init__(self, variant: Variant, counts_and_seq_lks: CountsAndSeqLks, allele_frequency: float, + artifact_logit: float, embedding: torch.Tensor, label: Label, maf: float, normal_maf: float): + self.embedding = embedding + + this_class = self.__class__ + self.int_array = torch.zeros(10, dtype=int) + self.int_array[this_class.CONTIG] = variant.contig + self.int_array[this_class.POSITION] = variant.position + self.int_array[this_class.REF] = variant.get_ref_as_int() # ref and alt are the base-5 encoding as integers + self.int_array[this_class.ALT] = variant.get_alt_as_int() + self.int_array[this_class.VAR_TYPE] = utils.Variation.get_type(variant.ref, variant.alt) # Variation is IntEnum so this is int + self.int_array[this_class.DEPTH] = counts_and_seq_lks.depth + self.int_array[this_class.ALT_COUNT] = counts_and_seq_lks.alt_count + self.int_array[this_class.NORMAL_DEPTH] = counts_and_seq_lks.normal_depth + self.int_array[this_class.NORMAL_ALT_COUNT] = counts_and_seq_lks.normal_alt_count + self.int_array[this_class.LABEL] = label + + self.float_array = torch.zeros(7, dtype=torch.float16) + self.float_array[this_class.SEQ_ERROR_LOG_LK] = counts_and_seq_lks.seq_error_log_lk + self.float_array[this_class.TLOD_FROM_M2] = -counts_and_seq_lks.seq_error_log_lk - math.log(counts_and_seq_lks.depth + 1) + self.float_array[this_class.NORMAL_SEQ_ERROR_LOG_LK] = counts_and_seq_lks.normal_seq_error_log_lk + self.float_array[this_class.ALLELE_FREQUENCY] = allele_frequency + self.float_array[this_class.ARTIFACT_LOGIT] = artifact_logit + self.float_array[this_class.MAF] = maf + self.float_array[this_class.NORMAL_MAF] = normal_maf + + def get_variant(self) -> Variant: + this_class = self.__class__ + subarray = self.int_array[this_class.CONTIG:this_class.ALT + 1] + return variant_from_int_array(subarray) + + def get_artifact_logit(self) -> float: + return self.float_array[self.__class__.ARTIFACT_LOGIT] + + +class PosteriorBatch: + + def __init__(self, data: List[PosteriorDatum]): + self.embeddings = torch.vstack([item.embedding for item in data]).float() + self.int_tensor = torch.vstack([item.int_array for item in data]) + self.float_tensor = torch.vstack([item.float_array for item in data]).float() + + self._size = len(data) + + def pin_memory(self): + self.embeddings = self.embeddings.pin_memory() + self.int_tensor = self.int_tensor.pin_memory() + self.float_tensor = self.float_tensor.pin_memory() + return self + + # dtype is just for floats!!! Better not convert the int tensor to a float accidentally! + def copy_to(self, device, dtype, non_blocking): + # For all non-tensor attributes, shallow copy is sufficient + new_batch = copy.copy(self) + + new_batch.embeddings = self.embeddings.to(device=device, dtype=dtype, non_blocking=non_blocking) + new_batch.int_tensor = self.int_tensor.to(device=device, non_blocking=non_blocking) + new_batch.float_tensor = self.float_tensor.to(device=device, dtype=dtype, non_blocking=non_blocking) + + return new_batch + + def get_variants(self) -> List[Variant]: + subarray_2d = self.int_tensor[:, PosteriorDatum.CONTIG:PosteriorDatum.ALT + 1] + return [variant_from_int_array(subarray) for subarray in subarray_2d] + + def get_variant_types(self) -> torch.IntTensor: + return self.int_tensor[:, PosteriorDatum.VAR_TYPE] + + def get_alt_counts(self) -> torch.Tensor: + return self.int_tensor[:, PosteriorDatum.ALT_COUNT] + + def get_depths(self) -> torch.Tensor: + return self.int_tensor[:, PosteriorDatum.DEPTH] + + def get_labels(self) -> torch.Tensor: + return self.int_tensor[:, PosteriorDatum.LABEL] + + def get_normal_alt_counts(self) -> torch.Tensor: + return self.int_tensor[:, PosteriorDatum.NORMAL_ALT_COUNT] + + def get_normal_depths(self) -> torch.Tensor: + return self.int_tensor[:, PosteriorDatum.NORMAL_DEPTH] + + def get_tlods_from_m2(self) -> torch.Tensor: + return self.float_tensor[:, PosteriorDatum.TLOD_FROM_M2] + + def get_allele_frequencies(self) -> torch.Tensor: + return self.float_tensor[:, PosteriorDatum.ALLELE_FREQUENCY] + + def get_artifact_logits(self) -> torch.Tensor: + return self.float_tensor[:, PosteriorDatum.ARTIFACT_LOGIT] + + def get_mafs(self) -> torch.Tensor: + return self.float_tensor[:, PosteriorDatum.MAF] + + def get_normal_mafs(self) -> torch.Tensor: + return self.float_tensor[:, PosteriorDatum.NORMAL_MAF] + + def size(self) -> int: + return self._size + + def get_normal_ref_counts(self) -> IntTensor: + return self.get_normal_depths() - self.get_normal_alt_counts() + + +class PosteriorDataset(Dataset): + def __init__(self, data: Iterable[PosteriorDatum], shuffle: bool = True): + self.data = data + + if shuffle: + random.shuffle(self.data) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, index) -> PosteriorDatum: + return self.data[index] + + def make_data_loader(self, batch_size: int, pin_memory: bool = False, num_workers: int = 0): + return DataLoader(dataset=self, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, collate_fn=PosteriorBatch) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/metrics/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/metrics/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/metrics/evaluation_metrics.py b/src/main/python/org/broadinstitute/hellbender/permutect/metrics/evaluation_metrics.py new file mode 100644 index 00000000000..dab64f3f89c --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/metrics/evaluation_metrics.py @@ -0,0 +1,436 @@ +import math +from collections import defaultdict +from typing import List + +import numpy as np +import torch +from matplotlib import pyplot as plt +from torch.utils.tensorboard import SummaryWriter + +from permutect.data.base_datum import BaseBatch, ArtifactBatch +from permutect.metrics import plotting +from permutect.utils import Variation, Call, Epoch, StreamingAverage + +MAX_COUNT = 18 # counts above this will be truncated +MAX_LOGIT = 15 +NUM_DATA_FOR_TENSORBOARD_PROJECTION = 10000 + + +def round_up_to_nearest_three(x: int): + return math.ceil(x / 3) * 3 + + +def multiple_of_three_bin_index(x: int): + return (round_up_to_nearest_three(x)//3) - 1 # -1 because zero is not a bin + + +MAX_BIN = multiple_of_three_bin_index(MAX_COUNT) + +def multiple_of_three_bin_indices(counts: torch.Tensor): + return (torch.ceil(counts/3) - 1).int() + + +def multiple_of_three_bin_index_to_count(idx: int): + return 3 * (idx + 1) + + +# round logit to nearest int, truncate to range, ending up with bins 0. . . 2*max_logit +def logit_to_bin(logit): + return min(max(round(logit), -MAX_LOGIT), MAX_LOGIT) + MAX_LOGIT + + +def bin_center(bin_idx): + return bin_idx - MAX_LOGIT + + +NUM_COUNT_BINS = round_up_to_nearest_three(MAX_COUNT) // 3 # zero is not a bin + + +def make_count_bin_mask(bin_index: int, counts: torch.Tensor): + assert bin_index < NUM_COUNT_BINS + count_bin_bottom = 3*bin_index + 1 + count_bin_top = 3*bin_index + 3 + return (count_bin_bottom <= counts) * (counts <= count_bin_top) + + +# simple container class for holding results of the posterior model and other things that get output to the VCF and +# tensorboard analysis +class PosteriorResult: + def __init__(self, artifact_logit: float, posterior_probabilities, log_priors, spectra_lls, normal_lls, label, alt_count, depth, var_type, embedding): + self.artifact_logit = artifact_logit + self.posterior_probabilities = posterior_probabilities + self.log_priors = log_priors + self.spectra_lls = spectra_lls + self.normal_lls = normal_lls + self.label = label + self.alt_count = alt_count + self.depth = depth + self.variant_type = var_type + self.embedding = embedding + + +# keep track of losses during training of artifact model +class LossMetrics: + def __init__(self): + self.labeled_loss = StreamingAverage() + self.unlabeled_loss = StreamingAverage() + + self.labeled_loss_by_type = {variant_type: StreamingAverage() for variant_type in Variation} + self.labeled_loss_by_count = {bin_idx: StreamingAverage() for bin_idx in range(NUM_COUNT_BINS)} + + def get_labeled_loss(self) -> float: + return self.labeled_loss.get() + + def get_unlabeled_loss(self) -> float: + return self.unlabeled_loss.get() + + def write_to_summary_writer(self, epoch_type: Epoch, epoch: int, summary_writer: SummaryWriter, prefix: str = ""): + if not self.labeled_loss.is_empty(): + summary_writer.add_scalar(prefix + epoch_type.name + "/Labeled Loss", self.labeled_loss.get(), epoch) + + if not self.unlabeled_loss.is_empty(): + summary_writer.add_scalar(prefix + epoch_type.name + "/Unlabeled Loss", self.unlabeled_loss.get(), epoch) + + for bin_idx, loss in self.labeled_loss_by_count.items(): + if not loss.is_empty(): + summary_writer.add_scalar( + prefix + epoch_type.name + "/Labeled Loss/By Count/" + str(multiple_of_three_bin_index_to_count(bin_idx)), loss.get(), epoch) + + for var_type, loss in self.labeled_loss_by_type.items(): + if not loss.is_empty(): + summary_writer.add_scalar(prefix + epoch_type.name + "/Labeled Loss/By Type/" + var_type.name, loss.get(), epoch) + + # record the losses (indexed by batch dimension) by type and count, as well as the total loss not stratified by type and count + # input losses are NOT weighted, but when recorded they are multiplied by weights if given + # losses are divided into labeled and unlabeled + # TODO: put type hint batch: BaseBatch | ArtifactBatch once docker update + def record_losses(self, losses: torch.Tensor, batch, weights: torch.Tensor): + # handle total loss + labeled_weights, unlabeled_weights = batch.get_is_labeled_mask() * weights, (1 - batch.get_is_labeled_mask()) * weights + + self.labeled_loss.record_with_weights(losses, labeled_weights) + self.unlabeled_loss.record_with_weights(losses, unlabeled_weights) + + # Note that we currently do not track unlabeled loss by type or by count + # by type + variant_types = batch.get_variant_types() + + # weight for losses is product of 1) the weights 2) the is_labeled mask, 3) the variant type mask + for var_type_idx, var_type in enumerate(Variation): + variant_type_mask = (variant_types == var_type_idx) + self.labeled_loss_by_type[var_type].record_with_weights(losses, labeled_weights * variant_type_mask) + + # by count + if isinstance(batch, BaseBatch): + # rather than individually record each count, and therefore send lots of stuff off the GPU, we + # send everything with the same bin simultaneously + bins = multiple_of_three_bin_indices(batch.get_alt_counts()) + for count_bin in range(MAX_BIN + 1): + indices = (bins == count_bin) + self.labeled_loss_by_count[count_bin].record_with_weights(losses[indices], labeled_weights[indices]) + + elif isinstance(batch, ArtifactBatch): + for count_bin_index in range(NUM_COUNT_BINS): + count_bin_mask = make_count_bin_mask(count_bin_index, batch.get_alt_counts()) + self.labeled_loss_by_count[count_bin_index].record_with_weights(losses, labeled_weights * count_bin_mask) + + +# predictions_and_labels is list of (predicted logit, actual label) tuples +# adjustment is the logit threshold that maximizes accuracy -- basically we're trying to find the shift such that +# a logit of 0 expresses 50/50 confidence +# output is the amount to be SUBTRACTED from logit to get a final adjusted logit +def calculate_logit_adjustment(predictions_and_labels, use_harmonic_mean: bool = False): + _, adjustment = plotting.get_roc_data(predictions_and_labels, given_threshold=None, sens_prec=False, use_harmonic_mean=use_harmonic_mean) + return adjustment + + +class EvaluationMetricsForOneEpochType: + def __init__(self): + # indexed by variant type, then count bin, then logit bin + self.acc_vs_logit = { + var_type: [[StreamingAverage() for _ in range(2 * MAX_LOGIT + 1)] for _ in range(NUM_COUNT_BINS)] for + var_type in Variation} + + self.acc_vs_logit_all_counts = { + var_type: [StreamingAverage() for _ in range(2 * MAX_LOGIT + 1)] for var_type in Variation} + + # indexed by variant type, then call type (artifact vs variant), then count bin + self.acc_vs_cnt = {var_type: defaultdict(lambda: [StreamingAverage() for _ in range(NUM_COUNT_BINS)]) for + var_type in Variation} + + # variant type -> (predicted logit, actual label) + self.roc_data = {var_type: [] for var_type in Variation} + + # variant type, count -> (predicted logit, actual label) + self.roc_data_by_cnt = {var_type: [[] for _ in range(NUM_COUNT_BINS)] for var_type in Variation} + + # Variant is an IntEnum, so variant_type can also be integer + # label is 1 for artifact / error; 0 for non-artifact / true variant + # correct_call is boolean -- was the prediction correct? + # the predicted logit is the logit corresponding to the predicted probability that call in question is an artifact / error + def record_call(self, variant_type: Variation, predicted_logit: float, label: float, correct_call, alt_count: int, weight: float = 1.0): + count_bin_index = multiple_of_three_bin_index(min(MAX_COUNT, alt_count)) + self.acc_vs_cnt[variant_type][Call.SOMATIC if label < 0.5 else Call.ARTIFACT][count_bin_index].record(correct_call, weight) + self.acc_vs_logit[variant_type][count_bin_index][logit_to_bin(predicted_logit)].record(correct_call, weight) + self.acc_vs_logit_all_counts[variant_type][logit_to_bin(predicted_logit)].record(correct_call, weight) + + self.roc_data[variant_type].append((predicted_logit, label)) + self.roc_data_by_cnt[variant_type][count_bin_index].append((predicted_logit, label)) + + # return a list of tuples. This outer list is over the two labels, Call.SOMATIC and Call.ARTIFACT. Each tuple consists of + # (list of alt counts (x axis), list of accuracies (y axis), the label) + def make_data_for_accuracy_plot(self, var_type: Variation): + non_empty_count_bins_by_label = { + label: [idx for idx in range(NUM_COUNT_BINS) if not self.acc_vs_cnt[var_type][label][idx].is_empty()] + for label in self.acc_vs_cnt[var_type].keys()} + + return [([multiple_of_three_bin_index_to_count(idx) for idx in non_empty_count_bins_by_label[label]], + [self.acc_vs_cnt[var_type][label][idx].get() for idx in non_empty_count_bins_by_label[label]], + label.name) for label in self.acc_vs_cnt[var_type].keys()] + + # similar tuple format but now it's (list of logits, list of accuracies, count) + def make_data_for_calibration_plot(self, var_type: Variation): + non_empty_logit_bins = [ + [idx for idx in range(2 * MAX_LOGIT + 1) if not self.acc_vs_logit[var_type][count_idx][idx].is_empty()] + for count_idx in range(NUM_COUNT_BINS)] + return [([bin_center(idx) for idx in non_empty_logit_bins[count_idx]], + [self.acc_vs_logit[var_type][count_idx][idx].get() for idx in + non_empty_logit_bins[count_idx]], + str(multiple_of_three_bin_index_to_count(count_idx))) for count_idx in + range(NUM_COUNT_BINS)] + + # now it's (list of logits, list of accuracies) + def make_data_for_calibration_plot_all_counts(self, var_type: Variation): + non_empty_logit_bins = [idx for idx in range(2 * MAX_LOGIT + 1) if not self.acc_vs_logit_all_counts[var_type][idx].is_empty()] + return ([bin_center(idx) for idx in non_empty_logit_bins], + [self.acc_vs_logit_all_counts[var_type][idx].get() for idx in non_empty_logit_bins]) + + def plot_accuracy(self, var_type: Variation, axis): + acc_vs_cnt_x_y_lab_tuples = self.make_data_for_accuracy_plot(var_type) + plotting.simple_plot_on_axis(axis, acc_vs_cnt_x_y_lab_tuples, None, None) + + def plot_calibration(self, var_type: Variation, axis): + acc_vs_logit_x_y_lab_tuples = self.make_data_for_calibration_plot(var_type) + plotting.simple_plot_on_axis(axis, acc_vs_logit_x_y_lab_tuples, None, None) + + def plot_calibration_all_counts(self, var_type: Variation, axis): + logits_list, accuracies_list = self.make_data_for_calibration_plot_all_counts(var_type) + plotting.simple_plot_on_axis(axis, [(logits_list, accuracies_list, "calibration")], None, None) + + def plot_roc_curve(self, var_type: Variation, axis, given_threshold: float = None, sens_prec: bool = False): + plotting.plot_accuracy_vs_accuracy_roc_on_axis([self.roc_data[var_type]], [None], axis, given_threshold, sens_prec) + + def plot_roc_curves_by_count(self, var_type: Variation, axis, given_threshold: float = None, sens_prec: bool = False): + plotting.plot_accuracy_vs_accuracy_roc_on_axis(self.roc_data_by_cnt[var_type], + [str(multiple_of_three_bin_index_to_count(idx)) for idx in + range(NUM_COUNT_BINS)], axis, given_threshold, sens_prec) + + # return variant type, count bin -> logit adjustment to be subtracted (so that maximum accuracy is at threshold of logit = 0) + def calculate_logit_adjustments(self, use_harmonic_mean: bool = False): + result = {var_type: [0.0 for _ in range(NUM_COUNT_BINS)] for var_type in Variation} + for var_type in Variation: + for cbin in range(NUM_COUNT_BINS): + data = self.roc_data_by_cnt[var_type][cbin] + if data: # leave adjustment at 0 if no data + result[var_type][cbin] = calculate_logit_adjustment(data, use_harmonic_mean) + + return result + + +class EvaluationMetrics: + def __init__(self): + # we will have a map from epoch type to EvaluationMetricsForOneEpochType + self.metrics = defaultdict(EvaluationMetricsForOneEpochType) + + # list of (PosteriorResult, Call) tuples + self.mistakes = [] + + # Variant is an IntEnum, so variant_type can also be integer + # label is 1 for artifact / error; 0 for non-artifact / true variant + # correct_call is boolean -- was the prediction correct? + # the predicted logit is the logit corresponding to the predicted probability that call in question is an artifact / error + def record_call(self, epoch_type: Epoch, variant_type: Variation, predicted_logit: float, label: float, correct_call, alt_count: int, weight: float = 1.0): + self.metrics[epoch_type].record_call(variant_type, predicted_logit, label, correct_call, alt_count, weight) + + # track bad calls when filtering is given an optional evaluation truth VCF + def record_mistake(self, posterior_result: PosteriorResult, call: Call): + self.mistakes.append((posterior_result, call)) + + def make_mistake_histograms(self, summary_writer: SummaryWriter): + # indexed by call then var_type, inner is a list of posterior results with that call and var type + posterior_result_mistakes_by_call_and_var_type = defaultdict(lambda: defaultdict(list)) + for posterior_result, call in self.mistakes: + posterior_result_mistakes_by_call_and_var_type[call][posterior_result.variant_type].append(posterior_result) + + mistake_calls = posterior_result_mistakes_by_call_and_var_type.keys() + num_rows = len(mistake_calls) + + af_fig, af_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='none', squeeze=False) + logit_fig, logit_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='none', squeeze=False) + ac_fig, ac_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='none', squeeze=False) + prob_fig, prob_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='none', squeeze=False) + + for row_idx, mistake_call in enumerate(mistake_calls): + for var_type in Variation: + posterior_results = posterior_result_mistakes_by_call_and_var_type[mistake_call][var_type] + + af_data = [pr.alt_count / pr.depth for pr in posterior_results] + plotting.simple_histograms_on_axis(af_axes[row_idx, var_type], [af_data], [""], 20) + + ac_data = [pr.alt_count for pr in posterior_results] + plotting.simple_histograms_on_axis(ac_axes[row_idx, var_type], [ac_data], [""], 20) + + logit_data = [pr.artifact_logit for pr in posterior_results] + plotting.simple_histograms_on_axis(logit_axes[row_idx, var_type], [logit_data], [""], 20) + + # posterior probability assigned to this incorrect call + prob_data = [pr.posterior_probabilities[mistake_call] for pr in posterior_results] + plotting.simple_histograms_on_axis(prob_axes[row_idx, var_type], [prob_data], [""], 20) + + variation_types = [var_type.name for var_type in Variation] + row_names = [mistake.name for mistake in mistake_calls] + + plotting.tidy_subplots(af_fig, af_axes, x_label="alt allele fraction", y_label="", row_labels=row_names, column_labels=variation_types) + plotting.tidy_subplots(ac_fig, ac_axes, x_label="alt count", y_label="", row_labels=row_names, + column_labels=variation_types) + plotting.tidy_subplots(logit_fig, logit_axes, x_label="artifact logit", y_label="", row_labels=row_names, + column_labels=variation_types) + plotting.tidy_subplots(prob_fig, prob_axes, x_label="mistake call probability", y_label="", row_labels=row_names, + column_labels=variation_types) + + summary_writer.add_figure("mistake allele fractions", af_fig) + summary_writer.add_figure("mistake alt counts", ac_fig) + summary_writer.add_figure("mistake artifact logits", logit_fig) + summary_writer.add_figure("probability assigned to mistake calls", prob_fig) + + def make_plots(self, summary_writer: SummaryWriter, given_thresholds=None, sens_prec: bool = False, epoch: int = None): + # given_thresholds is a dict from Variation to float (logit-scaled) used in the ROC curves + keys = self.metrics.keys() + num_rows = len(keys) + # grid of figures -- rows are epoch types, columns are variant types + # each subplot has two line graphs of accuracy vs alt count, one each for artifact, non-artifact + acc_vs_cnt_fig, acc_vs_cnt_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='all', squeeze=False) + roc_fig, roc_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='all', squeeze=False, figsize=(4 * len(Variation), 4 * len(keys)), dpi=200) + cal_fig, cal_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='all', squeeze=False) + cal_fig_all_counts, cal_axes_all_counts = plt.subplots(num_rows, len(Variation), sharex='all', sharey='all', squeeze=False) + roc_by_cnt_fig, roc_by_cnt_axes = plt.subplots(num_rows, len(Variation), sharex='all', sharey='all', squeeze=False, figsize=(4 * len(Variation), 4 * len(keys)), dpi=200) + + for row_idx, key in enumerate(keys): + metric = self.metrics[key] + for var_type in Variation: + given_threshold = None if given_thresholds is None else given_thresholds[var_type] + metric.plot_accuracy(var_type, acc_vs_cnt_axes[row_idx, var_type]) + metric.plot_calibration(var_type, cal_axes[row_idx, var_type]) + metric.plot_calibration_all_counts(var_type, cal_axes_all_counts[row_idx, var_type]) + metric.plot_roc_curve(var_type, roc_axes[row_idx, var_type], given_threshold, sens_prec) + metric.plot_roc_curves_by_count(var_type, roc_by_cnt_axes[row_idx, var_type], given_threshold, sens_prec) + # done collecting stats for all loaders and filling in subplots + + nonart_label = "sensitivity" if sens_prec else "non-artifact accuracy" + art_label = "precision" if sens_prec else "artifact accuracy" + + variation_types = [var_type.name for var_type in Variation] + row_names = [epoch_type.name for epoch_type in self.metrics.keys()] + plotting.tidy_subplots(acc_vs_cnt_fig, acc_vs_cnt_axes, x_label="alt count", y_label="accuracy", row_labels=row_names, column_labels=variation_types) + plotting.tidy_subplots(roc_fig, roc_axes, x_label=nonart_label, y_label=art_label, row_labels=row_names, column_labels=variation_types) + plotting.tidy_subplots(roc_by_cnt_fig, roc_by_cnt_axes, x_label=nonart_label, y_label=art_label, row_labels=row_names, column_labels=variation_types) + plotting.tidy_subplots(cal_fig, cal_axes, x_label="predicted logit", y_label="accuracy", row_labels=row_names, column_labels=variation_types) + plotting.tidy_subplots(cal_fig_all_counts, cal_axes_all_counts, x_label="predicted logit", y_label="accuracy", row_labels=row_names, column_labels=variation_types) + + summary_writer.add_figure("accuracy by alt count", acc_vs_cnt_fig, global_step=epoch) + summary_writer.add_figure(" accuracy by logit output by count", cal_fig, global_step=epoch) + summary_writer.add_figure(" accuracy by logit output", cal_fig_all_counts, global_step=epoch) + summary_writer.add_figure("sensitivity vs precision" if sens_prec else "variant accuracy vs artifact accuracy", roc_fig, global_step=epoch) + summary_writer.add_figure("sensitivity vs precision by alt count" if sens_prec else "variant accuracy vs artifact accuracy by alt count", roc_by_cnt_fig, global_step=epoch) + + +def sample_indices_for_tensorboard(indices: List[int]): + indices_np = np.array(indices) + + if len(indices_np) <= NUM_DATA_FOR_TENSORBOARD_PROJECTION: + return indices_np + + idx = np.random.choice(len(indices_np), size=NUM_DATA_FOR_TENSORBOARD_PROJECTION, replace=False) + return indices_np[idx] + + +class EmbeddingMetrics: + TRUE_POSITIVE = "true-positive" + FALSE_POSITIVE = "false-positive" + TRUE_NEGATIVE_ARTIFACT = "true-negative-artifact" # distinguish these because artifact and eg germline should embed differently + TRUE_NEGATIVE_NONARTIFACT = "true-negative-nonartifact" + TRUE_NEGATIVE = "true-negative" + FALSE_NEGATIVE_ARTIFACT = "false-negative-artifact" + TRUE_NEGATIVE_SEQ_ERROR = "true-negative-seq-error" + + def __init__(self): + # things we will collect for the projections + self.label_metadata = [] # list (extended by each batch) 1 if artifact, 0 if not + self.correct_metadata = [] # list (extended by each batch), 1 if correct prediction, 0 if not + self.type_metadata = [] # list of lists, strings of variant type + self.truncated_count_metadata = [] # list of lists + self.representations = [] # list of 2D tensors (to be stacked into a single 2D tensor), representations over batches + + def output_to_summary_writer(self, summary_writer: SummaryWriter, prefix: str = "", is_filter_variants: bool = False, epoch: int = None): + # downsample to a reasonable amount of UMAP data + all_metadata = list(zip(self.label_metadata, self.correct_metadata, self.type_metadata, self.truncated_count_metadata)) + + indices_by_correct_status = defaultdict(list) + + for n, correct_status in enumerate(self.correct_metadata): + indices_by_correct_status[correct_status].append(n) + + # note that if we don't have labeled truth, everything is boring + all_indices = set(range(len(all_metadata))) + interesting_indices = set(indices_by_correct_status[EmbeddingMetrics.TRUE_POSITIVE] + + indices_by_correct_status[EmbeddingMetrics.FALSE_POSITIVE] + + indices_by_correct_status[EmbeddingMetrics.FALSE_NEGATIVE_ARTIFACT]) + boring_indices = all_indices - interesting_indices + + '''if is_filter_variants: + boring_indices = np.array(indices_by_correct_status["unknown"] + indices_by_correct_status[EmbeddingMetrics.TRUE_NEGATIVE_ARTIFACT]) + + # if we have labeled truth, keep a few "boring" true negatives around; otherwise we only have "unknown"s + boring_count = len(interesting_indices) // 3 if len(interesting_indices) > 0 else len(boring_indices) + boring_to_keep = boring_indices[np.random.choice(len(boring_indices), size=boring_count, replace=False)] + idx = np.hstack((boring_to_keep, interesting_indices)) + + idx = np.random.choice(len(all_metadata), size=min(NUM_DATA_FOR_TENSORBOARD_PROJECTION, len(all_metadata)), replace=False) +''' + + stacked_representations = torch.vstack(self.representations) + + # read average embeddings stratified by variant type + for variant_type in Variation: + variant_name = variant_type.name + indices = set([n for n, type_name in enumerate(self.type_metadata) if type_name == variant_name]) + + interesting = interesting_indices & indices + boring = boring_indices & indices + boring_count = max(len(interesting) // 3, 100) if is_filter_variants else len(boring) + boring_to_keep = np.array([int(n) for n in boring])[np.random.choice(len(boring), size=boring_count, replace=False)] + idx = sample_indices_for_tensorboard(np.hstack((boring_to_keep, np.array([int(n) for n in interesting])))) + + summary_writer.add_embedding(stacked_representations[idx], + metadata=[all_metadata[round(n)] for n in idx.tolist()], + metadata_header=["Labels", "Correctness", "Types", "Counts"], + tag=prefix+"embedding for variant type " + variant_name, global_step=epoch) + + # read average embeddings stratified by alt count + for count_bin in range(NUM_COUNT_BINS): + count = multiple_of_three_bin_index_to_count(count_bin) + indices = set([n for n, alt_count in enumerate(self.truncated_count_metadata) if alt_count == str(count)]) + interesting = interesting_indices & indices + boring = boring_indices & indices + boring_count = max(len(interesting) // 3, 100) if is_filter_variants else len(boring) + boring_to_keep = np.array([int(n) for n in boring])[np.random.choice(len(boring), size=boring_count, replace=False)] + idx = sample_indices_for_tensorboard(np.hstack((boring_to_keep, np.array([int(n) for n in interesting])))) + + if len(idx) > 0: + summary_writer.add_embedding(stacked_representations[idx], + metadata=[all_metadata[round(n)] for n in idx.tolist()], + metadata_header=["Labels", "Correctness", "Types", "Counts"], + tag=prefix+"embedding for alt count " + str(count), global_step=epoch) + + + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/metrics/plotting.py b/src/main/python/org/broadinstitute/hellbender/permutect/metrics/plotting.py new file mode 100644 index 00000000000..31383f12b61 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/metrics/plotting.py @@ -0,0 +1,257 @@ +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +import math +from typing import List + + +# one or more simple plots of y data vs x data on shared axes +def simple_plot(x_y_lab_tuples, x_label, y_label, title): + fig = plt.figure() + curve = fig.gca() + labels_present = False + for (x, y, lab) in x_y_lab_tuples: + if lab is not None: + curve.plot(x, y, label=lab) + labels_present = True + else: + curve.plot(x, y) + curve.set_title(title) + curve.set_xlabel(x_label) + curve.set_ylabel(y_label) + if labels_present: + curve.legend() + return fig, curve + + +def simple_plot_on_axis(ax, x_y_lab_tuples, x_label, y_label): + labels_present = False + for (x, y, lab) in x_y_lab_tuples: + if lab is not None: + ax.plot(x, y, label=lab) + labels_present = True + else: + ax.plot(x, y) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + if labels_present: + ax.legend() + + +def simple_histograms_on_axis(ax, list_of_histogram_data, list_of_labels, num_bins): + ax.hist(list_of_histogram_data, bins=num_bins, alpha=0.5, label=list_of_labels) + + +# apply grouped bar plot to an axis (subplot) object +# heights by category is a dict of category to bar heights, where the nth bar height +# corresponds to the nth x label +def grouped_bar_plot_on_axis(ax, heights_by_category, x_labels, y_label): + spacing = 5 + bar_width = 0.7 * spacing / len(heights_by_category) + + for n, (category, heights) in enumerate(heights_by_category.items()): + offset = n * bar_width + x_positions = [offset + spacing*i for i in range(len(heights))] + ax.bar(x_positions, heights, width=bar_width, edgecolor='white', label=category) + + # Add xticks on the middle of the group bars + # plt.xlabel('group', fontweight='bold') + ticks_offset = bar_width * len(heights_by_category)/2 + ax.set_xticks([ticks_offset + spacing*i for i in range(len(x_labels))], labels=x_labels) + plt.setp(ax.get_xticklabels(), rotation=90) + ax.set_ylabel(y_label) + ax.legend() + + +# heights by category is a dict of category to bar heights, where the nth bar height +# corresponds to the nth x label +def grouped_bar_plot(heights_by_category, x_labels, y_label): + fig, ax = plt.subplots() + grouped_bar_plot_on_axis(ax, heights_by_category, x_labels, y_label) + return fig, ax + + +# labels are 0 for non-artifact, 1 for artifact +# predictions_and_labels has form [[(pred, label), (pred, label). . . for roc 1], [likewise for roc 2] etc] +# predictions are logits, not probabilities!! +def plot_accuracy_vs_accuracy_roc_on_axis(lists_of_predictions_and_labels, curve_labels, axis, given_threshold: float = None, sens_prec: bool = False): + x_y_lab_tuples = [] + small_dots = [] + big_dots = [] + for predictions_and_labels, curve_label in zip(lists_of_predictions_and_labels, curve_labels): + thresh_and_accs, _ = get_roc_data(predictions_and_labels, given_threshold, sens_prec, use_harmonic_mean=True) + x_y_lab_tuples.append(([x[1] for x in thresh_and_accs], [x[2] for x in thresh_and_accs], curve_label)) + + for threshold, art_acc, non_art_acc in thresh_and_accs: + if threshold == 0: + big_dots.append((art_acc, non_art_acc, 'rs')) # red square + elif given_threshold is not None and abs(threshold - given_threshold) < 0.001: + big_dots.append((art_acc, non_art_acc, 'kd')) # black diamond + else: + small_dots.append((art_acc, non_art_acc, 'go')) # green circle + + simple_plot_on_axis(axis, x_y_lab_tuples, "precision" if sens_prec else "artifact accuracy", "sensitivity" if sens_prec else "non-artifact accuracy") + for x, y, spec in small_dots: + axis.plot(x, y, spec, markersize=2,label="") # point + for x, y, spec in big_dots: + axis.plot(x, y, spec, markersize=6,label="") # point + + +# similar to the above, but labels are not known and we just have the predicted error probabilities +# we generate a theoretical ROC curve i.e. what the ROC curve would be if these predicted probabilities were +# perfectly calibrated +# labels are 0 for non-artifact, 1 for artifact +# predicted_error_probs is a list of list of floats between 0 and 1 +# curve labels is a list of strings +def plot_theoretical_roc_on_axis(predicted_error_probs, curve_labels, axis): + x_y_lab_tuples = [] + dots = [] + best_thresholds = [] + for error_probs, curve_label in zip(predicted_error_probs, curve_labels): + thresh_and_accs, best_threshold = get_theoretical_roc_data(error_probs) # best threshold is (threshold, art accuracay, non-art accuracy) + x_y_lab_tuples.append(([x[1] for x in thresh_and_accs], [x[2] for x in thresh_and_accs], curve_label)) + + for threshold, art_acc, non_art_acc in thresh_and_accs: + dots.append((art_acc, non_art_acc, 'go')) + dots.append((best_threshold[1], best_threshold[2], 'ro')) + best_thresholds.append(best_threshold) + + simple_plot_on_axis(axis, x_y_lab_tuples, "precision", "sensitivity") + for x, y, spec in dots: + axis.plot(x, y, spec, markersize=2,label="") # point + return best_thresholds + + +# input is list of (predicted artifact logit, binary artifact/non-artifact label) tuples +# 1st output is (threshold, accuracy on artifacts, accuracy on non-artifacts) tuples +# 2nd output is the threshold that maximizes mean (by default harmonic mean, otherwise artithmetic) of accuracy on artifacts and accuracy on non-artifacts +def get_roc_data(predictions_and_labels, given_threshold: float = None, sens_prec: bool = False, use_harmonic_mean: bool = True): + predictions_and_labels.sort(key=lambda p_and_l: p_and_l[0]) # sort from least to greatest artifact logit + num_calls = len(predictions_and_labels) + total_artifact = sum([label for _, label in predictions_and_labels]) + 0.0001 + total_non_artifact = num_calls - total_artifact + 0.0002 + # start at threshold = -infinity; that is, everything is called an artifact, and pick up one variant at a time + thresh_and_accs = [] # tuples of threshold, accuracy on artifacts, accuracy on non-artifacts + art_found, non_art_found = total_artifact, 0 + best_art_acc, best_non_art_acc = 1, 0 + next_threshold = -10 + best_threshold, best_mean = -99999, 0 + given_threshold_reached = (given_threshold is None) + for pred_logit, label in predictions_and_labels: + art_found -= label # if labeled as artifact, one artifact has slipped below threshold + non_art_found += (1 - label) # if labeled as non-artifact, one non-artifact has been gained + art_acc, non_art_acc = art_found / total_artifact, non_art_found / total_non_artifact + + # stuff for sensitivity-precision mode + tp = non_art_found # non-artifacts that pass threshold are true positives + fp = total_artifact - art_found # artifacts that do not fail threshold are false positives + + # in sensitivity-precision mode we care about the precision, not the absolute accuracy of artifact calls + art_metric = (tp/(tp + fp)) if sens_prec else art_acc + + harmonic_mean = 0 if (art_metric == 0 or non_art_acc == 0) else 1 / ((1 / art_metric) + (1 / non_art_acc)) + arithmetic_mean = (art_acc + non_art_acc) / 2 + mean = harmonic_mean if use_harmonic_mean else arithmetic_mean + + if mean > best_mean: + best_mean = mean + best_threshold = pred_logit + best_art_acc, best_non_art_acc = art_acc, non_art_acc + + if pred_logit > next_threshold: + thresh_and_accs.append((next_threshold, art_metric, non_art_acc)) + next_threshold = math.ceil(pred_logit) + + # the first time we reach a logit greater than the given threshold, we are basically *at* that threshold + if not given_threshold_reached and pred_logit > given_threshold: + thresh_and_accs.append((given_threshold, art_metric, non_art_acc)) + given_threshold_reached = True + + return thresh_and_accs, best_threshold + + +# input is list of artifact probabilities +# NOTE: this actually includes all errors, such as germline and seq error, but for later possible +# fixing code duplication with the above method we'll call it "artifact" +# 1st output is (threshold, accuracy on non-errors, accuracy on errors) tuples +# 2nd output is the threshold that maximizes harmonic mean of these two accuracies +def get_theoretical_roc_data(artifact_probs): + artifact_probs.sort(key=lambda p: p) # sort from least to greatest error probability + num_calls = len(artifact_probs) + total_artifact = sum([prob for prob in artifact_probs]) + 0.0001 + total_non_artifact = num_calls - total_artifact + 0.0002 + # start at threshold = 0; that is, everything is called an artifact, and pick up one variant at a time + # by increasing the probability threshold + thresh_and_accs = [] # tuples of threshold, accuracy on artifacts, accuracy on non-artifacts + art_found, non_art_found = total_artifact, 0 + next_threshold = -1 + best_threshold, best_harmonic_mean = (0, 1, 0), 0 # best threshold is threshold, precision, sensitivity + for prob in artifact_probs: + art_found -= prob # lose a fractional artifact + non_art_found += (1 - prob) # gain a fractional non-artifact + + tp = non_art_found # non-artifacts that pass threshold are true positives + fp = total_artifact - art_found # artifacts that do not fail threshold are false positives + + # in sensitivity-precision mode we care about the precision, not the absolute accuracy of artifact calls + sensitivity = tp / total_non_artifact + precision = tp / (tp + fp) + + harmonic_mean = 0 if (precision == 0 or sensitivity == 0) else 1 / ((1 / sensitivity) + (1 / precision)) + + if harmonic_mean > best_harmonic_mean: + best_harmonic_mean = harmonic_mean + best_threshold = (prob, precision, sensitivity) + + if prob > next_threshold: + thresh_and_accs.append((next_threshold, precision, sensitivity)) + next_threshold = math.ceil(prob*20)/20 # we are basically having thresholds of 0.05, 0.1, 0.15. . . + return thresh_and_accs, best_threshold + + +def tidy_subplots(figure: Figure, axes, x_label: str = None, y_label: str = None, + column_labels: List[str] = None, row_labels: List[str] = None, keep_axes_tick_labels=False): + """ + Combines various tidying operations on figures with subplots + 1. Removes the individual axis legends and replaces with a single figure legend. This assumes + that all axes have the same lines. + 2. Show x (y) labels and tick labels only in bottom row (leftmost column) + 3. Apply column headings and row labels + 4. Apply overall x and y labels to the figure as a whole + + figure matplotlib.figure.Figure + axes: 2D array of matplotlib.axes.Axes + + We assume these have been generated together via figure, axes = plt.subplots(. . .) + + """ + handles, labels = figure.get_axes()[0].get_legend_handles_labels() + figure.legend(handles, labels, loc='upper center') + + for ax in figure.get_axes(): + if not keep_axes_tick_labels: + ax.label_outer() # y tick labels only shown in leftmost column, x tick labels only shown on bottom row + ax.legend().set_visible(False) # hide the redundant identical subplot legends + + # remove the subplot labels and title -- these will be given manually to the whole figure and to the outer rows + ax.set_xlabel(None) + ax.set_ylabel(None) + ax.set_title(None) + + if x_label is not None: + figure.supxlabel(x_label) + + if y_label is not None: + figure.supylabel(y_label) + + if row_labels is not None: + assert len(row_labels) == len(axes) + for row_idx, label in enumerate(row_labels): + axes[row_idx][0].set_ylabel(label) # note that we use row 0 and set_title to make this a column heading + + if column_labels is not None: + assert len(column_labels) == len(axes[0]) + for col_idx, label in enumerate(column_labels): + axes[0][col_idx].set_title(label) + + figure.tight_layout() + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/parameters.py b/src/main/python/org/broadinstitute/hellbender/permutect/parameters.py new file mode 100644 index 00000000000..37ccadce525 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/parameters.py @@ -0,0 +1,146 @@ +from typing import List + +from permutect import constants + + +class BaseModelParameters: + """ + note that read layers and info layers exclude the input dimension + read_embedding_dimension: read tensors are linear-transformed to this dimension before + input to the transformer. This is also the output dimension of reads from the transformer + num_transformer_heads: number of attention heads in the read transformer. Must be a divisor + of the read_embedding_dimension + num_transformer_layers: number of layers of read transformer + """ + def __init__(self, read_layers: List[int], self_attention_hidden_dimension: int, + num_self_attention_layers: int, info_layers: List[int], aggregation_layers: List[int], + ref_seq_layers_strings: List[str], dropout_p: float, reweighting_range: float, batch_normalize: bool = False): + + self.read_layers = read_layers + self.info_layers = info_layers + self.ref_seq_layer_strings = ref_seq_layers_strings + self.self_attention_hidden_dimension = self_attention_hidden_dimension + self.num_self_attention_layers = num_self_attention_layers + self.aggregation_layers = aggregation_layers + self.dropout_p = dropout_p + self.reweighting_range = reweighting_range + self.batch_normalize = batch_normalize + + def output_dimension(self): + return self.aggregation_layers[-1] + + +def parse_base_model_params(args) -> BaseModelParameters: + read_layers = getattr(args, constants.READ_LAYERS_NAME) + info_layers = getattr(args, constants.INFO_LAYERS_NAME) + ref_seq_layer_strings = getattr(args, constants.REF_SEQ_LAYER_STRINGS_NAME) + self_attention_hidden_dimension = getattr(args, constants.SELF_ATTENTION_HIDDEN_DIMENSION_NAME) + num_self_attention_layers = getattr(args, constants.NUM_SELF_ATTENTION_LAYERS_NAME) + aggregation_layers = getattr(args, constants.AGGREGATION_LAYERS_NAME) + dropout_p = getattr(args, constants.DROPOUT_P_NAME) + reweighting_range = getattr(args, constants.REWEIGHTING_RANGE_NAME) + batch_normalize = getattr(args, constants.BATCH_NORMALIZE_NAME) + return BaseModelParameters(read_layers, self_attention_hidden_dimension, + num_self_attention_layers, info_layers, aggregation_layers, ref_seq_layer_strings, dropout_p, + reweighting_range, batch_normalize) + + +def add_base_model_params_to_parser(parser): + parser.add_argument('--' + constants.PRETRAINED_MODEL_NAME, required=False, type=str, help='optional pretrained base model to initialize training') + parser.add_argument('--' + constants.READ_LAYERS_NAME, nargs='+', type=int, required=True, + help='dimensions of hidden layers in the read embedding subnetwork, including the dimension of the embedding itself. ' + 'Negative values indicate residual skip connections') + parser.add_argument('--' + constants.SELF_ATTENTION_HIDDEN_DIMENSION_NAME, type=int, required=True, + help='hidden dimension of transformer keys and values') + parser.add_argument('--' + constants.NUM_SELF_ATTENTION_LAYERS_NAME, type=int, required=True, + help='number of symmetric gated MLP self-attention layers') + parser.add_argument('--' + constants.INFO_LAYERS_NAME, nargs='+', type=int, required=True, + help='dimensions of hidden layers in the info embedding subnetwork, including the dimension of the embedding itself. ' + 'Negative values indicate residual skip connections') + parser.add_argument('--' + constants.AGGREGATION_LAYERS_NAME, nargs='+', type=int, required=True, + help='dimensions of hidden layers in the aggregation subnetwork, excluding the dimension of input from lower subnetworks ' + 'and the dimension (1) of the output logit. Negative values indicate residual skip connections') + parser.add_argument('--' + constants.REF_SEQ_LAYER_STRINGS_NAME, nargs='+', type=str, required=True, + help='list of strings specifying convolution layers of the reference sequence embedding. For example ' + 'convolution/kernel_size=3/out_channels=64 pool/kernel_size=2 leaky_relu ' + 'convolution/kernel_size=3/dilation=2/out_channels=5 leaky_relu flatten linear/out_features=10') + parser.add_argument('--' + constants.DROPOUT_P_NAME, type=float, default=0.0, required=False, + help='dropout probability') + parser.add_argument('--' + constants.REWEIGHTING_RANGE_NAME, type=float, default=0.3, required=False, + help='magnitude of data augmentation by randomly weighted average of read embeddings. ' + 'a value of x yields random weights between 1 - x and 1 + x') + parser.add_argument('--' + constants.BATCH_NORMALIZE_NAME, action='store_true', + help='flag to turn on batch normalization') + + +# common parameters for training models +class TrainingParameters: + def __init__(self, batch_size: int, num_epochs: int, learning_rate: float = 0.001, + weight_decay: float = 0.01, num_workers: int = 0, num_calibration_epochs: int = 0, + inference_batch_size: int = 8192): + self.batch_size = batch_size + self.num_epochs = num_epochs + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.num_workers = num_workers + self.num_calibration_epochs = num_calibration_epochs + self.inference_batch_size = inference_batch_size + + +def parse_training_params(args) -> TrainingParameters: + learning_rate = getattr(args, constants.LEARNING_RATE_NAME) + weight_decay = getattr(args, constants.WEIGHT_DECAY_NAME) + batch_size = getattr(args, constants.BATCH_SIZE_NAME) + num_epochs = getattr(args, constants.NUM_EPOCHS_NAME) + num_calibration_epochs = getattr(args, constants.NUM_CALIBRATION_EPOCHS_NAME) + num_workers = getattr(args, constants.NUM_WORKERS_NAME) + inference_batch_size = getattr(args, constants.INFERENCE_BATCH_SIZE_NAME) + return TrainingParameters(batch_size, num_epochs, learning_rate, weight_decay, num_workers, num_calibration_epochs, inference_batch_size) + + +def add_training_params_to_parser(parser): + parser.add_argument('--' + constants.LEARNING_RATE_NAME, type=float, default=0.001, required=False, + help='learning rate') + parser.add_argument('--' + constants.WEIGHT_DECAY_NAME, type=float, default=0.0, required=False, + help='learning rate') + parser.add_argument('--' + constants.BATCH_SIZE_NAME, type=int, default=64, required=False, + help='batch size') + parser.add_argument('--' + constants.NUM_WORKERS_NAME, type=int, default=0, required=False, + help='number of subprocesses devoted to data loading, which includes reading from memory map, ' + 'collating batches, and transferring to GPU.') + parser.add_argument('--' + constants.NUM_EPOCHS_NAME, type=int, required=True, + help='number of epochs for primary training loop') + parser.add_argument('--' + constants.NUM_CALIBRATION_EPOCHS_NAME, type=int, default=0, required=False, + help='number of calibration-only epochs') + parser.add_argument('--' + constants.INFERENCE_BATCH_SIZE_NAME, type=int, default=8192, required=False, + help='batch size when performing model inference (not training)') + + +class ArtifactModelParameters: + def __init__(self, aggregation_layers: List[int], calibration_layers: List[int], + dropout_p: float = 0.0, batch_normalize: bool = False): + self.aggregation_layers = aggregation_layers + self.calibration_layers = calibration_layers + self.dropout_p = dropout_p + self.batch_normalize = batch_normalize + + +def parse_artifact_model_params(args) -> ArtifactModelParameters: + aggregation_layers = getattr(args, constants.AGGREGATION_LAYERS_NAME) + calibration_layers = getattr(args, constants.CALIBRATION_LAYERS_NAME) + dropout_p = getattr(args, constants.DROPOUT_P_NAME) + batch_normalize = getattr(args, constants.BATCH_NORMALIZE_NAME) + return ArtifactModelParameters(aggregation_layers, calibration_layers, dropout_p, batch_normalize) + + +def add_artifact_model_params_to_parser(parser): + parser.add_argument('--' + constants.AGGREGATION_LAYERS_NAME, nargs='+', type=int, required=True, + help='dimensions of hidden layers in the aggregation subnetwork, excluding the dimension of input from lower subnetworks ' + 'and the dimension (1) of the output logit. Negative values indicate residual skip connections') + parser.add_argument('--' + constants.CALIBRATION_LAYERS_NAME, nargs='+', type=int, required=True, + help='dimensions of hidden layers in the calibration subnetwork, excluding the dimension (1) of input logit and) ' + 'and the dimension (also 1) of the output logit.') + parser.add_argument('--' + constants.DROPOUT_P_NAME, type=float, default=0.0, required=False, + help='dropout probability') + parser.add_argument('--' + constants.BATCH_NORMALIZE_NAME, action='store_true', + help='flag to turn on batch normalization') \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_artifact_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_artifact_model.py new file mode 100644 index 00000000000..92ea74024dd --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_artifact_model.py @@ -0,0 +1,128 @@ +from permutect.test.test_utils import artificial_data +from permutect.data.base_dataset import BaseDataset, make_test_data_loader +from permutect.data.base_datum import BaseDatum +from typing import Iterable +from permutect.architecture.artifact_model import ArtifactModel +from permutect.parameters import ArtifactModelParameters +from permutect import utils +from permutect.tools.train_model import TrainingParameters + +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +import tempfile +from torch.utils.tensorboard import SummaryWriter + +BATCH_SIZE = 64 +CHUNK_SIZE = 100000 +NUM_EPOCHS = 50 +NUM_CALIBRATION_EPOCHS=10 +NUM_SPECTRUM_ITERATIONS = 100 +TRAINING_PARAMS = TrainingParameters(batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, num_calibration_epochs=NUM_CALIBRATION_EPOCHS, reweighting_range=0.3) + +REF_SEQ_LAYER_STRINGS = ['convolution/kernel_size=3/out_channels=64', + 'pool/kernel_size=3', + 'leaky_relu', + 'convolution/kernel_size=3/dilation=1/out_channels=5', + 'leaky_relu', + 'flatten', + 'linear/out_features=10'] +SMALL_MODEL_PARAMS = ArtifactModelParameters(read_embedding_dimension=12, + num_transformer_heads=3, transformer_hidden_dimension=10, num_transformer_layers=2, + info_layers=[5, 5], aggregation_layers=[5, 5, 5, 5], calibration_layers=[6], + ref_seq_layers_strings=REF_SEQ_LAYER_STRINGS, + dropout_p=0.2, batch_normalize=False, learning_rate=0.001, weight_decay=0.01, alt_downsample=20) + + +# Note that the test methods in this class also cover batching, samplers, datasets, and data loaders +def train_model_and_write_summary(hyperparams: ArtifactModelParameters, training_params: TrainingParameters, + data: Iterable[BaseDatum], summary_writer: SummaryWriter = None): + dataset = BaseDataset(data=data) + big_dataset = BigReadSetDataset(batch_size=training_params.batch_size, dataset=dataset, num_workers=2) + model = ArtifactModel(params=hyperparams, num_read_features=dataset.num_read_features(), num_info_features=dataset.num_info_features(), ref_sequence_length=dataset.ref_sequence_length()).float() + + model.learn(big_dataset, training_params.num_epochs, training_params.num_calibration_epochs, summary_writer=summary_writer, + reweighting_range=training_params.reweighting_range, hyperparams=hyperparams) + model.evaluate_model_after_training({"training": big_dataset.generate_batches(utils.Epoch.TRAIN)}, summary_writer) + return model + + +def test_big_data(): + training_dataset_file = "/Users/davidben/mutect3/just-dream-1/dream1-normal-medium-training.dataset" + big_dataset = BigReadSetDataset(batch_size=64, max_bytes_per_chunk=int(100*1e6), dataset_files=[training_dataset_file], num_workers=2) + params = SMALL_MODEL_PARAMS + training_params = TRAINING_PARAMS + + with tempfile.TemporaryDirectory() as tensorboard_dir: + summary_writer = SummaryWriter(tensorboard_dir) + model = ArtifactModel(params=params, num_read_features=big_dataset.num_read_features, num_info_features=big_dataset.num_info_features, ref_sequence_length=big_dataset.ref_sequence_length).float() + model.learn(big_dataset, training_params.num_epochs, training_params.num_calibration_epochs, summary_writer=summary_writer, + reweighting_range=training_params.reweighting_range, hyperparams=params) + model.evaluate_model_after_training({"training": big_dataset.generate_batches(utils.Epoch.TRAIN)}, summary_writer) + + events = EventAccumulator(tensorboard_dir) + events.Reload() + + +def test_separate_gaussian_data(): + # in the test for alt count agnostic, we make training data where variant alt counts are much larger than artifact + # alt counts and test data with a low alt allele fraction + for test_alt_fraction_agnostic in (False, True): + data = artificial_data.make_two_gaussian_data(1000) if not test_alt_fraction_agnostic else \ + artificial_data.make_two_gaussian_data(1000, vaf=0.5, downsample_variants_to_match_artifacts=False, alt_downsampling=20) + params = SMALL_MODEL_PARAMS + training_params = TRAINING_PARAMS + + with tempfile.TemporaryDirectory() as tensorboard_dir: + summary_writer = SummaryWriter(tensorboard_dir) + model = train_model_and_write_summary(hyperparams=params, training_params=training_params, data=data, summary_writer=summary_writer) + + # TODO: migrate this old stuff to test for PosteriorModel + # test_vaf = 0.05 if test_alt_fraction_agnostic else 0.5 + # test_data = artificial_data.make_two_gaussian_data(1000, is_training_data=False, vaf=test_vaf, unlabeled_fraction=0.0) + # test_dataset = ReadSetDataset(data=test_data) + # test_loader = make_test_data_loader(test_dataset, BATCH_SIZE) + # model.learn_spectra(test_loader, NUM_SPECTRUM_ITERATIONS, summary_writer=summary_writer) + + events = EventAccumulator(tensorboard_dir) + events.Reload() + + # TODO: these have been replaced with images, so it's not so simple to check the output quality from the tensorboard + # TODO: for now I can put in a breakpoint and manually run tensorboard --logdir + # TODO: to spot check the figures + # assert events.Scalars('Variant Sensitivity')[0].value > 0.98 + # assert events.Scalars('Artifact Sensitivity')[0].value > 0.98 + + +def test_wide_and_narrow_gaussian_data(): + data = artificial_data.make_wide_and_narrow_gaussian_data(10000) + params = SMALL_MODEL_PARAMS + training_params = TRAINING_PARAMS + + with tempfile.TemporaryDirectory() as tensorboard_dir: + summary_writer = SummaryWriter(tensorboard_dir) + model = train_model_and_write_summary(hyperparams=params, training_params=training_params, data=data, summary_writer=summary_writer) + + events = EventAccumulator(tensorboard_dir) + events.Reload() + + +# TODO: this test currently fails -- almost everything is considered an artifact +# TODO: I must investigate +def test_strand_bias_data(): + data = artificial_data.make_random_strand_bias_data(1000, is_training_data=True) + params = SMALL_MODEL_PARAMS # TODO: change!!!!!!! + training_params = TRAINING_PARAMS + + with tempfile.TemporaryDirectory() as tensorboard_dir: + summary_writer = SummaryWriter(tensorboard_dir) + model = train_model_and_write_summary(hyperparams=params, training_params=training_params, data=data, summary_writer=summary_writer) + + test_data = artificial_data.make_random_strand_bias_data(1000, is_training_data=False, vaf=0.25, unlabeled_fraction=0.0) + test_dataset = BaseDataset(data=test_data) + test_loader = make_test_data_loader(test_dataset, BATCH_SIZE) + model.learn_spectra(test_loader, NUM_SPECTRUM_ITERATIONS, summary_writer=summary_writer) + + events = EventAccumulator(tensorboard_dir) + events.Reload() + + assert events.Scalars('Variant Sensitivity')[0].value > 0.90 + assert events.Scalars('Artifact Sensitivity')[0].value > 0.90 diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_dna_sequence_convolution.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_dna_sequence_convolution.py new file mode 100644 index 00000000000..90a65928020 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_dna_sequence_convolution.py @@ -0,0 +1,18 @@ +import torch + +from permutect.architecture.dna_sequence_convolution import DNASequenceConvolution + + +def test_constructor(): + input_length = 20 + layer_strings = ['convolution/kernel_size=3/out_channels=64', + 'pool/kernel_size=2', + 'leaky_relu', + 'convolution/kernel_size=3/dilation=2/out_channels=5', + 'leaky_relu', + 'flatten', + 'linear/out_features=10'] + model = DNASequenceConvolution(layer_strings, input_length) + + data = torch.randn(8, 4, input_length) + output = model(data) \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_mlp.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_mlp.py new file mode 100644 index 00000000000..a2555b906e7 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_mlp.py @@ -0,0 +1,110 @@ +import torch +from permutect import utils + +from permutect.architecture.mlp import MLP + + +# test with artificial data where a*x = 0 is a perfect linear separator +def test_linearly_separable_data(): + input_dim = 3 + num_samples = 1000 + a = torch.rand(input_dim, 1) + x = torch.rand(num_samples, input_dim) + y = (torch.sign(torch.matmul(x,a)) + 1)/2 # labels are 0 / 1 + + layer_sizes = [input_dim, 1] + model = MLP(layer_sizes) + + loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x) + loss = loss_func(prediction, y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + assert loss_list[-1] < 0.01 + assert loss_list[1000] < loss_list[0] + assert loss_list[2000] < loss_list[1000] + assert loss_list[3000] < loss_list[2000] + + pred = torch.sign(torch.sigmoid(model.forward(x))-0.5) + lab = torch.sign(y-0.5) + + errors = torch.sum(torch.abs((pred - lab)/2)).item() + + # we should have perfect accuracy + assert errors < 1 + + +# test with annular data where y = 1 when 1/3 < norm(x) < 2/3 +def test_annular_data(): + input_dim = 3 + num_samples = 1000 + x = torch.randn(num_samples, input_dim)/torch.sqrt(torch.tensor([input_dim])) + + norms = torch.norm(x, dim=1) + y = (torch.sign(norms - 0.33) * torch.sign(0.66 - norms) + 1)/2 + + layer_sizes = [input_dim, 5, 5, 5, 5, 1] + model = MLP(layer_sizes) + + loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x) + loss = loss_func(torch.squeeze(prediction), y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + assert loss_list[-1] < 0.2 + + pred = torch.squeeze(torch.sign(torch.sigmoid(model.forward(x)) - 0.5)) + lab = torch.sign(y - 0.5) + + errors = torch.sum(torch.abs((pred - lab) / 2)).item() + + # we should have perfect accuracy + assert errors < 100 + + +# A single hidden layer should suffice to perfectly classify a 2D XOR pattern where +# the 1st and 3rd quadrants are labeled 0 and the 2nd and 4th are 1 +def test_xor_data(): + num_samples = 1000 + x = torch.randn(num_samples, 2) + y = (torch.sign(x[:, 0]*x[:, 1]) + 1)/2 # labels are 0 / 1 + + # Use a single hidden layer. Really only 2 neurons are needed, but more speeds convergence + layer_sizes = [2, 4, 1] + model = MLP(layer_sizes) + + loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x) + loss = loss_func(prediction.squeeze(), y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + assert loss_list[-1] < 0.15 + assert loss_list[1000] < loss_list[0] + assert loss_list[2000] < loss_list[1000] + assert loss_list[3000] < loss_list[2000] + + pred = torch.sign(torch.sigmoid(model.forward(x).detach().squeeze())-0.5) + lab = torch.sign(y-0.5) + + errors = torch.sum(torch.abs((pred - lab)/2)).item() + + # we should have near-perfect accuracy + assert errors < num_samples / 50 diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_monotonic.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_monotonic.py new file mode 100644 index 00000000000..b2aa093d482 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_monotonic.py @@ -0,0 +1,121 @@ +import torch +from permutect import utils + +from permutect.architecture.monotonic import MonoDense + + +def test_is_monotonic(): + input_dim = 3 + num_samples = 100 + x = torch.randn(num_samples, input_dim) + x[:, 0] = torch.arange(num_samples) / num_samples # 0th column is sorted least to greatest + x[:,1] = 0.5 # other columns are constant + x[:, 2] = 0.7 + + # an initialized random model. We're not learning anything here. + model = MonoDense(input_dimension=input_dim, output_dimensions=[12,-2,-2,1], num_increasing=1, num_decreasing=0) + + prediction = model.forward(x).flatten() + + sorted_prediction = prediction.sort().values + + assert torch.sum(torch.abs(prediction - sorted_prediction)) < 0.00001 + + +# test with artificial data where y = a dot x where a is a positive vector +def test_monotonic_linear_data(): + input_dim = 3 + num_samples = 100 + a = torch.ones(input_dim) + x = torch.rand(num_samples, input_dim) + y = torch.sum(x*a, dim=1) # row-by-row dot product + + model = MonoDense(input_dimension=input_dim, output_dimensions=[1], num_increasing=input_dim, num_decreasing=0) + loss_func = torch.nn.MSELoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x).resize_as(y) + loss = loss_func(prediction, y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + assert loss_list[-1] < 0.01 + assert loss_list[1000] < loss_list[0] + assert loss_list[2000] < loss_list[1000] + assert loss_list[3000] < loss_list[2000] + + +# test with artificial data where y = x1 - x2 + x3^2; monotonic increasing in x1, decreasing in x2, and neither in x3 +# we need more than one layer to get the quadratic +def test_mix(): + input_dim = 3 + num_samples = 100 + x = torch.rand(num_samples, input_dim) + y = x[:, 0] - x[:, 1] + torch.square(x[:, 2]) + + model = MonoDense(input_dimension=input_dim, output_dimensions=[6, 6, 6, 1], num_increasing=1, num_decreasing=1) + loss_func = torch.nn.MSELoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x).resize_as(y) + loss = loss_func(prediction, y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + assert loss_list[-1] < 0.01 + assert loss_list[1000] < loss_list[0] + assert loss_list[2000] < loss_list[1000] + assert loss_list[3000] < loss_list[2000] + + +def test_cant_learn_non_monotonic(): + input_dim = 1 + num_samples = 100 + x = torch.randn(num_samples, input_dim) + y = torch.square(x) + + model = MonoDense(input_dimension=input_dim, output_dimensions=[6, 6, 6, 1], num_increasing=1, num_decreasing=0) + loss_func = torch.nn.MSELoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x).resize_as(y) + loss = loss_func(prediction, y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + prediction = model.forward(x).resize_as(y) + loss = loss_func(prediction, y) + assert loss.item() > 0.1 + + +def test_cubic(): + input_dim = 3 + num_samples = 100 + x = torch.rand(num_samples, input_dim) + y = torch.sum(x**3, dim=1) + + model = MonoDense(input_dimension=input_dim, output_dimensions=[12, 12, 12, 1], num_increasing=input_dim, num_decreasing=0) + loss_func = torch.nn.MSELoss(reduction='mean') + optimizer = torch.optim.Adam(model.parameters()) + loss_list = [] + + num_epochs = 10000 + for epoch in range(num_epochs): + prediction = model.forward(x).resize_as(y) + loss = loss_func(prediction, y) + utils.backpropagate(optimizer, loss) + loss_list.append(loss.item()) + + assert loss_list[-1] < 0.01 + assert loss_list[1000] < loss_list[0] + assert loss_list[2000] < loss_list[1000] + assert loss_list[3000] < loss_list[2000] diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_overdispersed_binomial_mixture.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_overdispersed_binomial_mixture.py new file mode 100644 index 00000000000..a1f82303c28 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/architecture/test_overdispersed_binomial_mixture.py @@ -0,0 +1,70 @@ +import torch +from torch.distributions.binomial import Binomial +from permutect.architecture.overdispersed_binomial_mixture import OverdispersedBinomialMixture + + +# given a discrete distribution of allele fractions between 0 and 1, and desired depths, generate alt counts, +# fit a BetaBinomialMixture, and compare moments of the underlying Beta mixture (without the binomial part) to +# those of the empirical allele fractions +def test_on_discrete_af_distribution(fractions_1d: torch.Tensor, weights_1d: torch.Tensor, training_depths_1d: torch.Tensor, + num_components: int = 10, num_epochs=1000): + + idx = weights_1d.multinomial(num_samples=len(training_depths_1d), replacement=True) + empirical_fractions_1d = fractions_1d[idx] + + empirical_counts = Binomial(training_depths_1d, empirical_fractions_1d).sample().squeeze() + dummy_input = torch.ones(len(empirical_counts)) + + model = OverdispersedBinomialMixture(input_size=1, num_components=num_components) + model.fit(num_epochs=num_epochs, types_b=dummy_input, depths_1d_tensor=training_depths_1d, + alt_counts_1d_tensor=empirical_counts) + + # moments E[x], E[ln(x)], E[x ln(x)] + model_mean, model_log_mean, model_log_linear_mean = model.moments_of_underlying_beta_mixture(torch.Tensor([1])) + + given_mean = torch.sum(weights_1d * fractions_1d) + given_log_mean = torch.sum(weights_1d * torch.log(fractions_1d)) + given_log_linear_mean = torch.sum(weights_1d * fractions_1d * torch.log(fractions_1d)) + + assert torch.abs(model_mean - given_mean).item() < 0.02 + assert torch.abs(model_log_mean - given_log_mean).item() < 0.1 + assert torch.abs(model_log_linear_mean - given_log_linear_mean).item() < 0.03 + + +def test_single_component(): + num_samples = 1000 + for fraction in [0.05, 0.1, 0.5, 0.75]: + for depth in [10, 100, 1000]: + depths = depth*torch.ones(num_samples).int() + test_on_discrete_af_distribution(fractions_1d=torch.Tensor([fraction]), weights_1d=torch.Tensor([1.0]), training_depths_1d=depths) + + +def test_two_components(): + num_samples = 1000 + depth = 100 + depths = depth * torch.ones(num_samples).int() + for fraction_pair in [(0.1, 0.9), (0.2, 0.4), (0.4, 0.8)]: + for weight_pair in [(0.5, 0.5), (0.25, 0.75), (0.1, 0.9)]: + test_on_discrete_af_distribution(fractions_1d=torch.Tensor(fraction_pair), weights_1d=torch.Tensor(weight_pair), training_depths_1d=depths) + + +def test_three_components(): + num_samples = 1000 + depth = 100 + depths = depth * torch.ones(num_samples).int() + fractions = (0.1, 0.4, 0.7) + for unnormalized_weights in [(1, 1, 1), (1, 4, 1), (1, 5, 9)]: + weights = torch.Tensor(unnormalized_weights) / sum(unnormalized_weights) + test_on_discrete_af_distribution(fractions_1d=torch.Tensor(fractions), weights_1d=weights, training_depths_1d=depths) + + +def test_peak_over_background(): + num_samples = 1000 + depth = 100 + depths = depth * torch.ones(num_samples).int() + fractions = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9) + for unnormalized_weights in [(5, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 5, 1, 1, 1, 1)]: + weights = torch.Tensor(unnormalized_weights) / sum(unnormalized_weights) + test_on_discrete_af_distribution(fractions_1d=torch.Tensor(fractions), weights_1d=weights, + training_depths_1d=depths) + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/data/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_batch.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_batch.py new file mode 100644 index 00000000000..762402ed873 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_batch.py @@ -0,0 +1,49 @@ +import torch + +import permutect.data.base_datum +from permutect.data.base_datum import BaseDatum + + +# make a three-datum batch +from permutect.utils import Variation, Label + + +def test_base_batch(): + size = 3 + num_gatk_info_features = 5 + + variant_types = [Variation.SNV, Variation.SNV, Variation.INSERTION] + num_read_features = 11 + + # TODO: test different counts and also test that mixed counts fail + ref_counts = [11, 11, 11] + alt_counts = [6, 6, 6] + ref_sequence_strings = ["ACC", "GTG", "TAA"] + + ref_tensors = [torch.rand(n, num_read_features) for n in ref_counts] + alt_tensors = [torch.rand(n, num_read_features) for n in alt_counts] + + gatk_info_tensors = [torch.rand(num_gatk_info_features) for _ in range(size)] + labels = [Label.ARTIFACT, Label.VARIANT, Label.ARTIFACT] + sources = [0,0,0] + + data = [BaseDatum.from_gatk(ref_sequence_strings[n], variant_types[n], ref_tensors[n], alt_tensors[n], gatk_info_tensors[n], labels[n], sources[n]) for n in range(size)] + + batch = permutect.data.base_datum.BaseBatch(data) + + assert torch.equal(batch.get_ref_sequences_2d(), + torch.Tensor([ + [[1,0,0],[0,1,1],[0,0,0],[0,0,0]], + [[0,0,0],[0,0,0],[1,0,1],[0,1,0]], + [[0,1,1],[0,0,0],[0,0,0],[1,0,0]] + ]) + ) + assert batch.size() == 3 + + assert batch.get_reads_2d().shape[0] == sum(ref_counts) + sum(alt_counts) + assert batch.get_reads_2d().shape[1] == num_read_features + + assert batch.get_info_2d().shape[0] == 3 + + assert batch.labels.tolist() == [1.0, 0.0, 1.0] + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_dataset.py new file mode 100644 index 00000000000..d494768d998 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_dataset.py @@ -0,0 +1,93 @@ +import tempfile +import permutect.data.base_dataset as ds +import torch + +from permutect import utils + + +def test_line_to_tensor(): + line1 = "1.0 1.1 1.2 1.3" + tensor1 = ds.line_to_tensor(line1) + + # allow for tensor rounding error + assert torch.max(torch.tensor([1.0, 1.1, 1.2, 1.3]) - tensor1).item() < 0.001 + + +def test_read_integers(): + line1 = "1 2 3 4" + integers = ds.read_integers(line1) + + # allow for tensor rounding error + assert list(integers) == [1, 2, 3, 4] + + +def test_read_2d_tensor(): + tmp = tempfile.NamedTemporaryFile() + + with open(tmp.name, 'w') as f: + lines = [ + "1 2 3 4 5\n", + "6 7 8 9 10\n", + "11 12 13 14 15\n", + "NOTHING\n", + "10 20 30\n", + "40 50 60\n" + ] + + f.writelines(lines) + + with open(tmp.name) as f: + tensor1 = ds.read_2d_tensor(f, 3) + f.readline() + tensor2 = ds.read_2d_tensor(f, 2) + + assert list(tensor1.size()) == [3, 5] + assert list(tensor2.size()) == [2, 3] + + assert tensor1.tolist() == [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]] + + +def test_read_data(): + tmp = tempfile.NamedTemporaryFile() + + with open(tmp.name, 'w') as f: + lines = [ + "UNLABELED\n", + "1:12807, C->T\n", + "GGAGAGGCTTCGATGCCCCTC\n", + "0.192 0.000 0.000 1.000 1.000 1.000 1.000 1.000 1.000\n" + "5 2 10 2\n" + "23 30 1 0 42 21 362 42 320 0 0\n", + "27 32 1 0 9 21 290 9 281 0 0\n", + "24 30 0 0 15 21 353 15 338 0 0\n", + "24 30 0 0 15 21 353 15 338 0 0\n", + "24 30 0 0 15 21 353 15 338 0 0\n", + "24 32 0 0 23 42 351 77 274 0 0\n", + "23 31 0 0 35 30 350 65 285 0 0\n", + "36 4 26 2\n", + "0.879\n", + "4.55\n", + "ARTIFACT\n", + "1:13079, C->G\n", + "CCAGCTGGGTCGACAGACAGG\n", + "0.113 0.045 3.000 0.833 1.000 1.000 1.000 1.000 1.000\n", + "5 1 10 2\n", + "27 24 0 1 8 21 290 281 9 0 0\n", + "27 24 0 1 8 21 290 281 9 0 0\n", + "27 24 0 1 8 21 290 281 9 0 0\n", + "27 24 0 1 8 21 290 281 9 0 0\n", + "27 24 0 1 8 21 290 281 9 0 0\n", + "23 31 1 1 4 21 346 341 5 0 0\n", + "78 4 75 2\n", + "12.5\n", + "-73.3\n" + ] + + f.writelines(lines) + + data = list(ds.read_data(tmp.name)) + assert len(data) == 2 + assert data[0].label == utils.Label.UNLABELED + assert torch.max(data[0].get_info_tensor_1d() - torch.tensor([0.192, 0.000, 0.000, 1.000, 1.000, 1.000, 1.000, 1.000, 1.000] + [1, 0, 0])).item() < 0.001 + + assert data[1].reads_2d.size()[0] == 6 diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_datum.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_datum.py new file mode 100644 index 00000000000..70b10c36ea1 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/data/test_base_datum.py @@ -0,0 +1,31 @@ +import numpy as np +import torch + +from permutect.data import base_datum +from permutect.utils import Variation, Label + + +def test_base_datum(): + num_ref_reads = 6 + num_alt_reads = 8 + num_read_features = 11 + num_info_features = 9 + + ref_tensor = torch.rand(num_ref_reads, num_read_features) + alt_tensor = torch.rand(num_alt_reads, num_read_features) + gatk_info_tensor = torch.rand(num_info_features) + label = Label.ARTIFACT + source = 0 + + snv_datum = base_datum.BaseDatum.from_gatk("AC", Variation.SNV, ref_tensor, alt_tensor, gatk_info_tensor, label, source) + + assert torch.equal(snv_datum.get_ref_sequence_1d(), torch.Tensor([0,1])) + assert torch.equal(snv_datum.reads_2d, np.vstack([ref_tensor, alt_tensor])) + assert snv_datum.get_variant_type() == Variation.SNV + assert snv_datum.label == label + + insertion_datum = base_datum.BaseDatum.from_gatk("GT", Variation.INSERTION, ref_tensor, alt_tensor, gatk_info_tensor, label, source) + deletion_datum = base_datum.BaseDatum.from_gatk("TT", Variation.DELETION, ref_tensor, alt_tensor, gatk_info_tensor, label, source) + assert insertion_datum.get_variant_type() == Variation.INSERTION + assert deletion_datum.get_variant_type() == Variation.DELETION + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/test_utils/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/test_utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/test_utils/artificial_data.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/test_utils/artificial_data.py new file mode 100644 index 00000000000..091928e7055 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/test_utils/artificial_data.py @@ -0,0 +1,167 @@ +import torch +import random +from permutect.data.base_datum import BaseDatum +from permutect.utils import Variation, Label +from numpy.random import binomial + + +NUM_READ_FEATURES = 5 + + +# random isotropic Gaussian tensor, dilated by different amounts in each dimension +def make_random_tensor(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + assert mean.size() == std.size() + + # TODO: random normal needs same length as mean + return mean + std * torch.randn(len(mean)) + + +class RandomGATKInfoGenerator: + def __init__(self, mean: torch.Tensor, std: torch.Tensor): + assert len(mean) == len(std) + self.mean = mean + self.std = std + + def generate(self) -> torch.Tensor: + return make_random_tensor(self.mean, self.std) + + +class RandomReadGenerator: + def __init__(self, mean: torch.Tensor, std: torch.Tensor): + assert len(mean) == len(std) + self.mean = mean + self.std = std + + def generate(self, num_reads: int) -> torch.Tensor: + return torch.vstack([make_random_tensor(self.mean, self.std) for _ in range(num_reads)]) + + +def make_random_data(art_gatk_info_gen: RandomGATKInfoGenerator, var_gatk_info_gen: RandomGATKInfoGenerator, art_read_gen: RandomReadGenerator, + var_read_gen: RandomReadGenerator, num_data: int, artifact_fraction=0.5, unlabeled_fraction=0.1, + indel_fraction=0.2, ref_downsampling=10, alt_downsampling=10, is_training_data=True, vaf=0.5, + downsample_variants_to_match_artifacts=True): + data = [] + for _ in range(0, num_data): + position = random.randint(1, 1000000) + + # generate label + artifact = random.uniform(0,1) < artifact_fraction + unlabeled = random.uniform(0,1) < unlabeled_fraction + label = Label.UNLABELED if unlabeled else (Label.ARTIFACT if artifact else Label.VARIANT) + + # generate variant type + indel = random.uniform(0,1) < indel_fraction + variant_type = (Variation.DELETION if random.uniform(0, 1) < 0.5 else Variation.INSERTION) if indel else Variation.SNV + + ref_count = ref_downsampling + # if it's test data and a variant, we draw the alt count from the AF spectrum + # the pd_alt_count is only relevant for variant test data + # we assume artifact used the original alts but variant was downsampled from a het + pd_tumor_depth = 100 + pd_alt_count = random.randint(3, 10) if artifact else binomial(pd_tumor_depth, vaf) + alt_count = random.randint(3, 10) if (artifact or (is_training_data and downsample_variants_to_match_artifacts)) \ + else min(alt_downsampling, pd_alt_count) + + if alt_count == 0: + continue + + gatk_info_tensor = (art_gatk_info_gen if artifact else var_gatk_info_gen).generate() + + ref_tensor = var_read_gen.generate(ref_count) + alt_tensor = (art_read_gen if artifact else var_read_gen).generate(alt_count) + + # TODO: vary the reference sequence string? + data.append(BaseDatum.from_gatk("GTAAAGT", variant_type, ref_tensor, alt_tensor, gatk_info_tensor, label)) + + return data + + +# artifacts and variants are identical except 0th component of artifact read tensors all have the same sign, whereas +# each non-artifact read is randomly + or - +def make_random_strand_bias_data(num_data: int, artifact_fraction=0.5, unlabeled_fraction=0.1, + ref_downsampling=10, alt_downsampling=10, is_training_data=True, vaf=0.5, num_gatk_info_features=5): + data = [] + for _ in range(0, num_data): + # generate label + artifact = random.uniform(0,1) < artifact_fraction + unlabeled = random.uniform(0, 1) < unlabeled_fraction + label = Label.UNLABELED if unlabeled else (Label.ARTIFACT if artifact else Label.VARIANT) + + # if it's test data and a variant, we draw the alt count from the AF spectrum + # the alt_count is only relevant for variant test data + # we assume artifact used the original alts but variant was downsampled from a het + depth = 100 + alt_count = random.randint(3, 10) if artifact else binomial(depth, vaf) + alt_tensor_size = random.randint(3, 10) if (artifact or is_training_data) else min(alt_downsampling, alt_count) + + if alt_count == 0: + continue + + gatk_info_tensor = torch.zeros(num_gatk_info_features) + + # before modifying the 0th element, it's all uniform Gaussian data + ref_tensor = torch.randn(ref_downsampling, NUM_READ_FEATURES) + alt_tensor = torch.randn(alt_tensor_size, NUM_READ_FEATURES) + + if artifact: + sign = 1 if random.uniform(0,1) < 0.5 else -1 + alt_tensor[:, 0] = sign * torch.abs(alt_tensor[:, 0]) + + # TODO: vary the reference sequence string? + data.append(BaseDatum.from_gatk("TGGGAATG", Variation.SNV, ref_tensor, alt_tensor, gatk_info_tensor, label)) + + return data + + +# good and bad data are generated by distinct gaussians +def make_two_gaussian_data(num_data, is_training_data=True, vaf=0.5, artifact_fraction=0.5, unlabeled_fraction=0.1, + indel_fraction=0.2, ref_downsampling=10, alt_downsampling=10, downsample_variants_to_match_artifacts=True): + var_gatk_info_mean = torch.tensor([-1]*9) + var_gatk_info_std = torch.tensor([1]*9) + art_gatk_info_mean = torch.tensor([1] * 9) + art_gatk_info_std = torch.tensor([1] * 9) + + var_gatk_info_gen = RandomGATKInfoGenerator(var_gatk_info_mean, var_gatk_info_std) + art_gatk_info_gen = RandomGATKInfoGenerator(art_gatk_info_mean, art_gatk_info_std) + + var_read_mean = torch.tensor([-1] * 11) + var_read_std = torch.tensor([1] * 11) + art_read_mean = torch.tensor([1] * 11) + art_read_std = torch.tensor([1] * 11) + + var_read_gen = RandomReadGenerator(var_read_mean, var_read_std) + art_read_gen = RandomReadGenerator(art_read_mean, art_read_std) + + return make_random_data(art_gatk_info_gen=art_gatk_info_gen, var_gatk_info_gen=var_gatk_info_gen, art_read_gen=art_read_gen, + var_read_gen=var_read_gen, num_data=num_data, is_training_data=is_training_data, vaf=vaf, + artifact_fraction=artifact_fraction, unlabeled_fraction=unlabeled_fraction, + indel_fraction=indel_fraction, ref_downsampling=ref_downsampling, alt_downsampling=alt_downsampling, + downsample_variants_to_match_artifacts=downsample_variants_to_match_artifacts) + + +# good and bad data are generated by gaussians with same mean (0) but artifacts are much more spread out +def make_wide_and_narrow_gaussian_data(num_data, is_training_data=True, vaf=0.5, artifact_fraction=0.5, unlabeled_fraction=0.1, + indel_fraction=0.2, ref_downsampling=10, alt_downsampling=10, downsample_variants_to_match_artifacts=True): + var_gatk_info_mean = torch.tensor([0]*9) + var_gatk_info_std = torch.tensor([1]*9) + art_gatk_info_mean = torch.tensor([0] * 9) + art_gatk_info_std = torch.tensor([2] * 9) + + var_gatk_info_gen = RandomGATKInfoGenerator(var_gatk_info_mean, var_gatk_info_std) + art_gatk_info_gen = RandomGATKInfoGenerator(art_gatk_info_mean, art_gatk_info_std) + + var_read_mean = torch.tensor([0] * 11) + var_read_std = torch.tensor([1] * 11) + art_read_mean = torch.tensor([0] * 11) + art_read_std = torch.tensor([2] * 11) + + var_read_gen = RandomReadGenerator(var_read_mean, var_read_std) + art_read_gen = RandomReadGenerator(art_read_mean, art_read_std) + + return make_random_data(art_gatk_info_gen=art_gatk_info_gen, var_gatk_info_gen=var_gatk_info_gen, art_read_gen=art_read_gen, + var_read_gen=var_read_gen, num_data=num_data, is_training_data=is_training_data, vaf=vaf, + artifact_fraction=artifact_fraction, unlabeled_fraction=unlabeled_fraction, + indel_fraction=indel_fraction, ref_downsampling=ref_downsampling, + alt_downsampling=alt_downsampling, downsample_variants_to_match_artifacts=downsample_variants_to_match_artifacts) + + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_filter_variants.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_filter_variants.py new file mode 100644 index 00000000000..3bf47fe57ae --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_filter_variants.py @@ -0,0 +1,43 @@ +import tempfile +from argparse import Namespace + +from permutect import constants +from permutect.tools import filter_variants + + +def test_filtering_on_dream1_chr20(): + # Inputs + artifact_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/artifact-model.pt' + + mutect2_vcf = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/mutect2_chr20.vcf' + maf_segments = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/segments.table' + contigs_table = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/contigs.table' + filtering_dataset = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/test_chr20.dataset' + + # Outputs + permutect_vcf = tempfile.NamedTemporaryFile() + tensorboard_dir = tempfile.TemporaryDirectory() + + filtering_args = Namespace() + setattr(filtering_args, constants.INPUT_NAME, mutect2_vcf) + setattr(filtering_args, constants.TEST_DATASET_NAME, filtering_dataset) + setattr(filtering_args, constants.M3_MODEL_NAME, artifact_model) + setattr(filtering_args, constants.OUTPUT_NAME, permutect_vcf.name) + setattr(filtering_args, constants.TENSORBOARD_DIR_NAME, tensorboard_dir.name) + setattr(filtering_args, constants.BATCH_SIZE_NAME, 64) + setattr(filtering_args, constants.NUM_WORKERS_NAME, 0) + setattr(filtering_args, constants.CHUNK_SIZE_NAME, 100000) + setattr(filtering_args, constants.NUM_SPECTRUM_ITERATIONS_NAME, 2) + setattr(filtering_args, constants.HET_BETA_NAME, 10) + setattr(filtering_args, constants.SPECTRUM_LEARNING_RATE_NAME, 0.001) + setattr(filtering_args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME, -10.0) + setattr(filtering_args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME, -10.0) + setattr(filtering_args, constants.GENOMIC_SPAN_NAME, 60000000) + setattr(filtering_args, constants.MAF_SEGMENTS_NAME, None) + setattr(filtering_args, constants.CONTIGS_TABLE_NAME, contigs_table) + setattr(filtering_args, constants.NORMAL_MAF_SEGMENTS_NAME, None) + setattr(filtering_args, constants.GERMLINE_MODE_NAME, False) + setattr(filtering_args, constants.NO_GERMLINE_MODE_NAME, False) + + filter_variants.main_without_parsing(filtering_args) + h = 9 \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_pipeline.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_pipeline.py new file mode 100644 index 00000000000..618dddfbcc8 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_pipeline.py @@ -0,0 +1,90 @@ +from argparse import Namespace +import tempfile + +from permutect.tools import preprocess_dataset, train_model, filter_variants +from permutect import constants + + +def test_on_dream1(): + # Input Files + training_datasets = ["/Users/davidben/permutect/just-dream-1/dream1-normal-small-training.dataset"] + #mutect2_vcf = "/Users/davidben/permutect/dream-vcfs/dream1-50000.vcf" + mutect2_vcf = "/Users/davidben/permutect/integration-test/dream1-mutect2-small.vcf" + #filtering_dataset = "/Users/davidben/permutect/just-dream-1/dream1-test.dataset" + filtering_dataset = "/Users/davidben/permutect/integration-test/dream1-small-test.dataset" + + # Intermediate and Output Files + training_data_tarfile = tempfile.NamedTemporaryFile() + saved_artifact_model = tempfile.NamedTemporaryFile() + training_tensorboard_dir = tempfile.TemporaryDirectory() + filtering_tensorboard_dir = tempfile.TemporaryDirectory() + filtered_mutect3_vcf = tempfile.NamedTemporaryFile() + + # STEP 1: preprocess the plain text training dataset yielding a training tarfile + preprocess_args = Namespace() + setattr(preprocess_args, constants.CHUNK_SIZE_NAME, 1e6) + setattr(preprocess_args, constants.TRAINING_DATASETS_NAME, training_datasets) + setattr(preprocess_args, constants.OUTPUT_NAME, training_data_tarfile.name) + setattr(preprocess_args, constants.SOURCES_NAME, [0]) + preprocess_dataset.main_without_parsing(preprocess_args) + + # STEP 2: train a model + train_model_args = Namespace() + setattr(train_model_args, constants.READ_LAYERS_NAME, [10, 10, 10]) + setattr(train_model_args, constants.SELF_ATTENTION_HIDDEN_DIMENSION_NAME, 20) + setattr(train_model_args, constants.NUM_SELF_ATTENTION_LAYERS_NAME, 3) + setattr(train_model_args, constants.INFO_LAYERS_NAME, [30, 30, 30]) + setattr(train_model_args, constants.AGGREGATION_LAYERS_NAME, [30, 30, 30, 30]) + setattr(train_model_args, constants.CALIBRATION_LAYERS_NAME, [6,6]) + cnn_layer_strings = ['convolution/kernel_size=3/out_channels=64', + 'pool/kernel_size=2', + 'leaky_relu', + 'convolution/kernel_size=3/dilation=2/out_channels=5', + 'leaky_relu', + 'flatten', + 'linear/out_features=10'] + setattr(train_model_args, constants.REF_SEQ_LAYER_STRINGS_NAME, cnn_layer_strings) + setattr(train_model_args, constants.DROPOUT_P_NAME, 0.0) + setattr(train_model_args, constants.LEARNING_RATE_NAME, 0.001) + setattr(train_model_args, constants.WEIGHT_DECAY_NAME, 0.01) + setattr(train_model_args, constants.BATCH_NORMALIZE_NAME, False) + setattr(train_model_args, constants.LEARN_ARTIFACT_SPECTRA_NAME, True) # could go either way + setattr(train_model_args, constants.GENOMIC_SPAN_NAME, 100000) + + # Training data inputs + setattr(train_model_args, constants.TRAIN_TAR_NAME, training_data_tarfile.name) + + # training hyperparameters + setattr(train_model_args, constants.REWEIGHTING_RANGE_NAME, 0.3) + setattr(train_model_args, constants.BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.NUM_WORKERS_NAME, 2) + setattr(train_model_args, constants.NUM_EPOCHS_NAME, 2) + setattr(train_model_args, constants.NUM_CALIBRATION_EPOCHS_NAME, 1) + setattr(train_model_args, constants.NUM_REFLESS_EPOCHS_NAME, 2) + + # path to saved model + setattr(train_model_args, constants.OUTPUT_NAME, saved_artifact_model.name) + setattr(train_model_args, constants.TENSORBOARD_DIR_NAME, training_tensorboard_dir.name) + + train_model.main_without_parsing(train_model_args) + + # STEP 3: call variants + filtering_args = Namespace() + setattr(filtering_args, constants.INPUT_NAME, mutect2_vcf) + setattr(filtering_args, constants.TEST_DATASET_NAME, filtering_dataset) + setattr(filtering_args, constants.M3_MODEL_NAME, saved_artifact_model.name) + setattr(filtering_args, constants.OUTPUT_NAME, filtered_mutect3_vcf.name) + setattr(filtering_args, constants.TENSORBOARD_DIR_NAME, filtering_tensorboard_dir.name) + setattr(filtering_args, constants.BATCH_SIZE_NAME, 64) + setattr(filtering_args, constants.CHUNK_SIZE_NAME, 100000) + setattr(filtering_args, constants.NUM_SPECTRUM_ITERATIONS_NAME, 10) + setattr(filtering_args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME, -10.0) + setattr(filtering_args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME, -10.0) + setattr(filtering_args, constants.GENOMIC_SPAN_NAME, 100000) + setattr(filtering_args, constants.MAF_SEGMENTS_NAME, None) + setattr(filtering_args, constants.NORMAL_MAF_SEGMENTS_NAME, None) + setattr(filtering_args, constants.GERMLINE_MODE_NAME, False) + setattr(filtering_args, constants.NO_GERMLINE_MODE_NAME, False) + + filter_variants.main_without_parsing(filtering_args) + h = 9 diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_preprocess_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_preprocess_dataset.py new file mode 100644 index 00000000000..435f0f62502 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_preprocess_dataset.py @@ -0,0 +1,27 @@ +from argparse import Namespace +import tempfile + +from permutect.data import base_datum +from permutect.data.base_dataset import BaseDataset +from permutect.tools import preprocess_dataset +from permutect import constants, utils +from permutect.utils import extract_to_temp_dir + + +def test_on_10_megabases_singular(): + training_datasets = ["/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/training-dataset.txt"] + training_data_tarfile = tempfile.NamedTemporaryFile() + + preprocess_args = Namespace() + setattr(preprocess_args, constants.CHUNK_SIZE_NAME, 1e6) + setattr(preprocess_args, constants.TRAINING_DATASETS_NAME, training_datasets) + setattr(preprocess_args, constants.OUTPUT_NAME, training_data_tarfile.name) + setattr(preprocess_args, constants.SOURCES_NAME, [0]) + preprocess_dataset.main_without_parsing(preprocess_args) + + with tempfile.TemporaryDirectory() as train_temp_dir: + training_files = extract_to_temp_dir(training_data_tarfile.name, train_temp_dir) + for training_file in training_files: + base_data_list = base_datum.load_list_of_base_data(training_file) + + dataset = BaseDataset(data_tarfile=training_data_tarfile.name, num_folds=10) \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_prune_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_prune_dataset.py new file mode 100644 index 00000000000..0971d0ad780 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_prune_dataset.py @@ -0,0 +1,53 @@ +import tempfile +from argparse import Namespace + +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from permutect import constants +from permutect.data.base_dataset import BaseDataset +from permutect.tools import prune_dataset + + +def test_prune_dataset(): + # Inputs + training_data_tarfile = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/preprocessed-dataset.tar' + base_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/base-model.pt' + + # Outputs + pruned_dataset = tempfile.NamedTemporaryFile() + training_tensorboard_dir = tempfile.TemporaryDirectory() + + # STEP 2: train a model + prune_dataset_args = Namespace() + setattr(prune_dataset_args, constants.AGGREGATION_LAYERS_NAME, [20, 20, 20]) + setattr(prune_dataset_args, constants.CALIBRATION_LAYERS_NAME, [6,6]) + setattr(prune_dataset_args, constants.DROPOUT_P_NAME, 0.0) + setattr(prune_dataset_args, constants.BATCH_NORMALIZE_NAME, False) + setattr(prune_dataset_args, constants.LEARN_ARTIFACT_SPECTRA_NAME, True) # could go either way + setattr(prune_dataset_args, constants.GENOMIC_SPAN_NAME, 100000) + + # Training data inputs + setattr(prune_dataset_args, constants.TRAIN_TAR_NAME, training_data_tarfile) + setattr(prune_dataset_args, constants.BASE_MODEL_NAME, base_model) + + setattr(prune_dataset_args, constants.CHUNK_SIZE_NAME, 2e9) + + # training hyperparameters + setattr(prune_dataset_args, constants.BATCH_SIZE_NAME, 64) + setattr(prune_dataset_args, constants.NUM_WORKERS_NAME, 2) + setattr(prune_dataset_args, constants.NUM_EPOCHS_NAME, 2) + setattr(prune_dataset_args, constants.NUM_CALIBRATION_EPOCHS_NAME, 1) + setattr(prune_dataset_args, constants.LEARNING_RATE_NAME, 0.001) + setattr(prune_dataset_args, constants.WEIGHT_DECAY_NAME, 0.01) + + # path to saved model + setattr(prune_dataset_args, constants.OUTPUT_NAME, pruned_dataset.name) + setattr(prune_dataset_args, constants.TENSORBOARD_DIR_NAME, training_tensorboard_dir.name) + + prune_dataset.main_without_parsing(prune_dataset_args) + + events = EventAccumulator(training_tensorboard_dir.name) + events.Reload() + + pruned_base_dataset = BaseDataset(data_tarfile=pruned_dataset, num_folds=10) + h = 99 + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_train_base_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_train_base_model.py new file mode 100644 index 00000000000..2707abc810b --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_train_base_model.py @@ -0,0 +1,56 @@ +import tempfile +from argparse import Namespace + +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from permutect import constants +from permutect.tools import train_base_model +from permutect.architecture.base_model import load_base_model + + +def test_train_base_model(): + training_data_tarfile = "/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/preprocessed-dataset.tar" + saved_base_model = tempfile.NamedTemporaryFile() + training_tensorboard_dir = tempfile.TemporaryDirectory() + + train_model_args = Namespace() + setattr(train_model_args, constants.READ_LAYERS_NAME, [10, 10, 10]) + setattr(train_model_args, constants.SELF_ATTENTION_HIDDEN_DIMENSION_NAME, 20) + setattr(train_model_args, constants.NUM_SELF_ATTENTION_LAYERS_NAME, 2) + setattr(train_model_args, constants.INFO_LAYERS_NAME, [10, 10]) + setattr(train_model_args, constants.AGGREGATION_LAYERS_NAME, [20, 20, 20]) + cnn_layer_strings = ['convolution/kernel_size=3/out_channels=64', + 'pool/kernel_size=2', + 'leaky_relu', + 'flatten', + 'linear/out_features=10'] + setattr(train_model_args, constants.REF_SEQ_LAYER_STRINGS_NAME, cnn_layer_strings) + setattr(train_model_args, constants.DROPOUT_P_NAME, 0.0) + setattr(train_model_args, constants.BATCH_NORMALIZE_NAME, False) + + setattr(train_model_args, constants.LEARNING_METHOD_NAME, 'SEMISUPERVISED') + + # Training data inputs + setattr(train_model_args, constants.TRAIN_TAR_NAME, training_data_tarfile) + setattr(train_model_args, constants.PRETRAINED_MODEL_NAME, None) + + # training hyperparameters + setattr(train_model_args, constants.REWEIGHTING_RANGE_NAME, 0.3) + setattr(train_model_args, constants.BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.INFERENCE_BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.NUM_WORKERS_NAME, 0) + setattr(train_model_args, constants.NUM_EPOCHS_NAME, 2) + setattr(train_model_args, constants.NUM_CALIBRATION_EPOCHS_NAME, 0) + setattr(train_model_args, constants.LEARNING_RATE_NAME, 0.001) + setattr(train_model_args, constants.WEIGHT_DECAY_NAME, 0.01) + + # path to saved model + setattr(train_model_args, constants.OUTPUT_NAME, saved_base_model.name) + setattr(train_model_args, constants.TENSORBOARD_DIR_NAME, training_tensorboard_dir.name) + + train_base_model.main_without_parsing(train_model_args) + + events = EventAccumulator(training_tensorboard_dir.name) + events.Reload() + + loaded_base_model = load_base_model(saved_base_model) diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_train_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_train_model.py new file mode 100644 index 00000000000..89a09127f95 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/test/tools/test_train_model.py @@ -0,0 +1,57 @@ +import tempfile +from argparse import Namespace + +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from permutect import constants, utils +from permutect.tools import train_model +from permutect.architecture.artifact_model import load_artifact_model + + +def test_train_model(): + # Inputs + training_data_tarfile = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/preprocessed-dataset.tar' + base_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/base-model.pt' + + # Outputs + saved_artifact_model = tempfile.NamedTemporaryFile() + training_tensorboard_dir = tempfile.TemporaryDirectory() + + # STEP 2: train a model + train_model_args = Namespace() + setattr(train_model_args, constants.AGGREGATION_LAYERS_NAME, [20, 20, 20]) + setattr(train_model_args, constants.CALIBRATION_LAYERS_NAME, [6,6]) + setattr(train_model_args, constants.DROPOUT_P_NAME, 0.0) + setattr(train_model_args, constants.BATCH_NORMALIZE_NAME, False) + setattr(train_model_args, constants.LEARN_ARTIFACT_SPECTRA_NAME, True) # could go either way + setattr(train_model_args, constants.GENOMIC_SPAN_NAME, 100000) + + # Training data inputs + setattr(train_model_args, constants.TRAIN_TAR_NAME, training_data_tarfile) + setattr(train_model_args, constants.BASE_MODEL_NAME, base_model) + + # training hyperparameters + setattr(train_model_args, constants.BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.INFERENCE_BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.NUM_WORKERS_NAME, 0) + setattr(train_model_args, constants.NUM_EPOCHS_NAME, 2) + setattr(train_model_args, constants.NUM_CALIBRATION_EPOCHS_NAME, 1) + setattr(train_model_args, constants.LEARNING_RATE_NAME, 0.001) + setattr(train_model_args, constants.WEIGHT_DECAY_NAME, 0.01) + + # path to saved model + setattr(train_model_args, constants.OUTPUT_NAME, saved_artifact_model.name) + setattr(train_model_args, constants.TENSORBOARD_DIR_NAME, training_tensorboard_dir.name) + + train_model.main_without_parsing(train_model_args) + + events = EventAccumulator(training_tensorboard_dir.name) + events.Reload() + + device = utils.gpu_if_available() + loaded_artifact_model, artifact_log_priors, artifact_spectra_state_dict = load_artifact_model(saved_artifact_model, device=device) + assert artifact_log_priors is not None + assert artifact_spectra_state_dict is not None + + print(artifact_log_priors) + h = 99 + diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/__init__.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/edit_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/edit_dataset.py new file mode 100644 index 00000000000..5e901de8df0 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/edit_dataset.py @@ -0,0 +1,129 @@ +import argparse +import os +import tarfile +import tempfile +from enum import Enum + +import psutil +import torch.utils.data + +from permutect.data import base_datum +from tqdm.autonotebook import tqdm + +from permutect import constants +from permutect.data.base_dataset import BaseDataset +from permutect.utils import Label + + +class EditType(Enum): + UNLABEL_ARTIFACTS = "unlabel_artifacts" + UNLABEL_VARIANTS = "unlabel_variants" + UNLABEL_EVERYTHING = "unlabel_everything" + REMOVE_ARTIFACTS = "remove_artifacts" + REMOVE_VARIANTS = "remove_variants" + KEEP_EVERYTHING = "keep_everything" + + +# generates BaseDatum(s) from the original dataset that *pass* the pruning thresholds +def generate_edited_data(base_datasets, edit_type: str, source: int): + pbar = tqdm(enumerate(torch.utils.data.ConcatDataset(base_datasets)), mininterval=60) + + for n, base_datum in pbar: + if source is not None: + base_datum.set_source(source) + + if edit_type == EditType.UNLABEL_ARTIFACTS.value: + if base_datum.label == Label.ARTIFACT: + base_datum.set_label(Label.UNLABELED) + yield base_datum + elif edit_type == EditType.UNLABEL_VARIANTS.value: + if base_datum.label == Label.VARIANT: + base_datum.set_label(Label.UNLABELED) + yield base_datum + elif edit_type == EditType.UNLABEL_EVERYTHING.value: + base_datum.set_label(Label.UNLABELED) + yield base_datum + elif edit_type == EditType.REMOVE_ARTIFACTS.value: + if base_datum.label != Label.ARTIFACT: + yield base_datum + elif edit_type == EditType.REMOVE_VARIANTS.value: + if base_datum.label != Label.VARIANT: + yield base_datum + elif edit_type == EditType.KEEP_EVERYTHING.value: + yield base_datum + else: + raise Exception(f"edit type {edit_type} not implemented yet") + + +# takes a ReadSet generator and organizes into buffers. +# TODO: probably code duplication since the generator is already pruned +def generate_output_data_buffers(output_data_generator, max_bytes_per_chunk: int): + buffer, bytes_in_buffer = [], 0 + for datum in output_data_generator: + + buffer.append(datum) + bytes_in_buffer += datum.size_in_bytes() + if bytes_in_buffer > max_bytes_per_chunk: + print(f"Memory usage percent: {psutil.virtual_memory().percent:.1f}") + print(f"{bytes_in_buffer} bytes in chunk") + yield buffer + buffer, bytes_in_buffer = [], 0 + + # There will be some data left over, in general. + if buffer: + yield buffer + + +def make_output_training_dataset(pruned_data_buffer_generator, output_tarfile): + pruned_data_files = [] + for base_data_list in pruned_data_buffer_generator: + with tempfile.NamedTemporaryFile(delete=False) as train_data_file: + base_datum.save_list_base_data(base_data_list, train_data_file) + pruned_data_files.append(train_data_file.name) + + # bundle them in a tarfile + with tarfile.open(output_tarfile, "w") as train_tar: + for train_file in pruned_data_files: + train_tar.add(train_file, arcname=os.path.basename(train_file)) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Mutect3 artifact model') + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=int(2e9), required=False, + help='size in bytes of output binary data files') + parser.add_argument('--' + constants.DATASET_EDIT_TYPE_NAME, type=str, required=True, + help='how to modify the dataset') + parser.add_argument('--' + constants.SOURCE_NAME, type=int, required=False, help='new source integer to apply') + + # input / output + parser.add_argument('--' + constants.TRAIN_TAR_NAME, nargs='+', type=str, required=True, + help='tarfile(s) of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='path to pruned dataset file') + + return parser.parse_args() + + +def main_without_parsing(args): + original_tarfiles = getattr(args, constants.TRAIN_TAR_NAME) # list of files + output_tarfile = getattr(args, constants.OUTPUT_NAME) + chunk_size = getattr(args, constants.CHUNK_SIZE_NAME) + edit_type = getattr(args, constants.DATASET_EDIT_TYPE_NAME) + new_source = getattr(args, constants.SOURCE_NAME) + base_datasets = map(lambda original_tarfile: BaseDataset(data_tarfile=original_tarfile), original_tarfiles) + + # generate ReadSets + output_data_generator = generate_edited_data(base_datasets, edit_type, new_source) + + # generate List[ReadSet]s + output_data_buffer_generator = generate_output_data_buffers(output_data_generator, chunk_size) + + make_output_training_dataset(output_data_buffer_generator, output_tarfile=output_tarfile) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/evaluate_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/evaluate_model.py new file mode 100644 index 00000000000..2631418d385 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/evaluate_model.py @@ -0,0 +1,42 @@ +import argparse +from torch.utils.tensorboard import SummaryWriter +from permutect import constants, utils +from permutect.data.base_dataset import BaseDataset +from permutect.architecture.artifact_model import load_artifact_model + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--' + constants.EVALUATION_TAR_NAME, type=str, required=True, + help='tarfile of evaluation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Mutect3 artifact model from train_model.py') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard') + parser.add_argument('--' + constants.BATCH_SIZE_NAME, type=int, default=64, required=False, help='batch size') + parser.add_argument('--' + constants.NUM_WORKERS_NAME, type=int, default=0, required=False, + help='number of subprocesses devoted to data loading, which includes reading from memory map, ' + 'collating batches, and transferring to GPU.') + + return parser.parse_args() + + +def main_without_parsing(args): + data_tarfile = getattr(args, constants.EVALUATION_TAR_NAME) + saved_artifact_model = getattr(args, constants.M3_MODEL_NAME) + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + batch_size = getattr(args, constants.BATCH_SIZE_NAME) + num_workers = getattr(args, constants.NUM_WORKERS_NAME) + + dataset = BaseDataset(data_tarfile=data_tarfile, num_folds=10) + device = utils.gpu_if_available() + artifact_model, _, _ = load_artifact_model(saved_artifact_model, device=device) + summary_writer = SummaryWriter(tensorboard_dir) + artifact_model.evaluate_model_after_training(dataset, batch_size, num_workers, summary_writer) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/filter_variants.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/filter_variants.py new file mode 100644 index 00000000000..9a50c2ad1fa --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/filter_variants.py @@ -0,0 +1,399 @@ +import argparse +import math +from collections import defaultdict +from typing import Set + +import cyvcf2 +import psutil +import torch +from intervaltree import IntervalTree +from torch.utils.tensorboard import SummaryWriter +from tqdm.autonotebook import tqdm + +from permutect import constants, utils +from permutect.architecture.artifact_model import ArtifactModel, load_base_model_and_artifact_model +from permutect.architecture.posterior_model import PosteriorModel +from permutect.architecture.base_model import BaseModel +from permutect.data import base_dataset, plain_text_data, base_datum +from permutect.data.base_datum import Variant +from permutect.data.posterior import PosteriorDataset, PosteriorDatum, PosteriorBatch +from permutect.data.artifact_dataset import ArtifactDataset +from permutect.metrics.evaluation_metrics import EvaluationMetrics, PosteriorResult, EmbeddingMetrics, \ + round_up_to_nearest_three, MAX_COUNT +from permutect.utils import Call, find_variant_type, Label, Variation, Epoch, trim_alleles_on_right + +TRUSTED_M2_FILTERS = {'contamination'} + +POST_PROB_INFO_KEY = 'POST' +ARTIFACT_LOD_INFO_KEY = 'ARTLOD' +LOG_PRIOR_INFO_KEY = 'PRIOR' +SPECTRA_LOG_LIKELIHOOD_INFO_KEY = 'SPECLL' +NORMAL_LOG_LIKELIHOOD_INFO_KEY = 'NORMLL' + +FILTER_NAMES = [call_type.name.lower() for call_type in Call] + + +# the inverse of the sigmoid function. Convert a probability to a logit. +def prob_to_logit(prob: float): + clipped_prob = 0.5 + 0.9999999 * (prob - 0.5) + return math.log(clipped_prob / (1 - clipped_prob)) + + +def get_first_numeric_element(variant, key): + tuple_or_scalar = variant.INFO[key] + return tuple_or_scalar[0] if type(tuple_or_scalar) is tuple else tuple_or_scalar + + +# if alt and ref alleles are not in minimal representation ie have redundant matching bases at the end, trim them + + +# TODO: contigs stored as integer index must be converted back to string to compare VCF variants with dataset variants!!! +def encode(contig: str, position: int, ref: str, alt: str): + trimmed_ref, trimmed_alt = trim_alleles_on_right(ref, alt) + return contig + ':' + str(position) + ':' + base_datum.truncate_bases_if_necessary(trimmed_alt) + + +def encode_datum(variant: Variant, contig_index_to_name_map): + contig_name = contig_index_to_name_map[variant.contig] + return encode(contig_name, variant.position, variant.ref, variant.alt) + + +def encode_variant(v: cyvcf2.Variant, zero_based=False): + alt = v.ALT[0] # TODO: we're assuming biallelic + ref = v.REF + start = (v.start + 1) if zero_based else v.start + return encode(v.CHROM, start, ref, alt) + + +def filters_to_keep_from_m2(v: cyvcf2.Variant) -> Set[str]: + return set([]) if v.FILTER is None else set(v.FILTER.split(";")).intersection(TRUSTED_M2_FILTERS) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--' + constants.INPUT_NAME, required=True, help='unfiltered input Mutect2 VCF') + parser.add_argument('--' + constants.TEST_DATASET_NAME, required=True, + help='plain text dataset file corresponding to variants in input VCF') + parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Permutect model from train_model.py') + parser.add_argument('--' + constants.CONTIGS_TABLE_NAME, required=True, help='table of contig names vs integer indices') + parser.add_argument('--' + constants.OUTPUT_NAME, required=True, help='path to output filtered VCF') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard') + parser.add_argument('--' + constants.BATCH_SIZE_NAME, type=int, default=64, required=False, help='batch size') + parser.add_argument('--' + constants.NUM_WORKERS_NAME, type=int, default=0, required=False, + help='number of subprocesses devoted to data loading, which includes reading from memory map, ' + 'collating batches, and transferring to GPU.') + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=100000, required=False, help='size in bytes of intermediate binary datasets') + parser.add_argument('--' + constants.NUM_SPECTRUM_ITERATIONS_NAME, type=int, default=10, required=False, + help='number of epochs for fitting allele fraction spectra') + parser.add_argument('--' + constants.SPECTRUM_LEARNING_RATE_NAME, type=float, default=0.001, required=False, + help='learning rate for fitting allele fraction spectra') + parser.add_argument('--' + constants.INITIAL_LOG_VARIANT_PRIOR_NAME, type=float, default=-10.0, required=False, + help='initial value for natural log prior of somatic variants') + parser.add_argument('--' + constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME, type=float, default=-10.0, required=False, + help='initial value for natural log prior of artifacts') + parser.add_argument('--' + constants.GENOMIC_SPAN_NAME, type=float, required=True, + help='number of sites considered by Mutect2, including those lacking variation or artifacts, hence absent from input dataset. ' + 'Necessary for learning priors since otherwise rates of artifacts and variants would be overinflated.') + parser.add_argument('--' + constants.MAF_SEGMENTS_NAME, required=False, + help='copy-number segmentation file from GATK containing minor allele fractions. ' + 'Useful for modeling germline variation as the minor allele fraction determines the distribution of germline allele counts.') + parser.add_argument('--' + constants.NORMAL_MAF_SEGMENTS_NAME, required=False, + help='copy-number segmentation file from GATK containing minor allele fractions in the normal/control sample') + + parser.add_argument('--' + constants.GERMLINE_MODE_NAME, action='store_true', + help='flag for genotyping both somatic and somatic variants distinctly but considering both ' + 'as non-errors (true positives), which affects the posterior threshold set by optimal F1 score') + parser.add_argument('--' + constants.HET_BETA_NAME, type=float, required=False, + help='beta shape parameter for germline spectrum beta binomial if we want to override binomial') + + parser.add_argument('--' + constants.NO_GERMLINE_MODE_NAME, action='store_true', + help='flag for not genotyping germline events so that the only possibilities considered are ' + 'somatic, artifact, and sequencing error. This is useful for certain validation where ' + 'pseudo-somatic events are created by mixing germline events at varying fractions') + return parser.parse_args() + + +def get_segmentation(segments_file) -> defaultdict: + + result = defaultdict(IntervalTree) + if segments_file is None: + return result + + print("reading segmentation file") + with open(segments_file, 'r') as file: + for line in file: + if line.startswith("#") or (line.startswith("contig") and "minor_allele_fraction" in line): + continue + tokens = line.split() + contig, start, stop, maf = tokens[0], int(tokens[1]), int(tokens[2]), float(tokens[3]) + if stop > start: # IntervalTree throws error if start == stop + result[contig][start:stop] = maf + + return result + + +def main_without_parsing(args): + make_filtered_vcf(saved_artifact_model_path=getattr(args, constants.M3_MODEL_NAME), + initial_log_variant_prior=getattr(args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME), + initial_log_artifact_prior=getattr(args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME), + test_dataset_file=getattr(args, constants.TEST_DATASET_NAME), + contigs_table=getattr(args, constants.CONTIGS_TABLE_NAME), + input_vcf=getattr(args, constants.INPUT_NAME), + output_vcf=getattr(args, constants.OUTPUT_NAME), + batch_size=getattr(args, constants.BATCH_SIZE_NAME), + num_workers=getattr(args, constants.NUM_WORKERS_NAME), + chunk_size=getattr(args, constants.CHUNK_SIZE_NAME), + num_spectrum_iterations=getattr(args, constants.NUM_SPECTRUM_ITERATIONS_NAME), + spectrum_learning_rate=getattr(args, constants.SPECTRUM_LEARNING_RATE_NAME), + tensorboard_dir=getattr(args, constants.TENSORBOARD_DIR_NAME), + genomic_span=getattr(args, constants.GENOMIC_SPAN_NAME), + germline_mode=getattr(args, constants.GERMLINE_MODE_NAME), + no_germline_mode=getattr(args, constants.NO_GERMLINE_MODE_NAME), + het_beta=getattr(args, constants.HET_BETA_NAME), + segmentation=get_segmentation(getattr(args, constants.MAF_SEGMENTS_NAME)), + normal_segmentation=get_segmentation(getattr(args, constants.NORMAL_MAF_SEGMENTS_NAME))) + + +def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float, + test_dataset_file, contigs_table, input_vcf, output_vcf, batch_size: int, num_workers: int, chunk_size: int, num_spectrum_iterations: int, + spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False, het_beta: float = None, + segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)): + print("Loading artifact model and test dataset") + contig_index_to_name_map = {} + with open(contigs_table) as file: + while line := file.readline().strip(): + contig, index = line.split() + contig_index_to_name_map[int(index)] = contig + + device = utils.gpu_if_available() + base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \ + load_base_model_and_artifact_model(saved_artifact_model_path, device=device) + + posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features, het_beta=het_beta) + posterior_data_loader = make_posterior_data_loader(test_dataset_file, input_vcf, contig_index_to_name_map, + base_model, artifact_model, batch_size, num_workers=num_workers, chunk_size=chunk_size, segmentation=segmentation, normal_segmentation=normal_segmentation) + + print("Learning AF spectra") + summary_writer = SummaryWriter(tensorboard_dir) + + num_ignored_sites = genomic_span - len(posterior_data_loader.dataset) + # here is where pretrained artifact priors and spectra are used if given + + posterior_model.learn_priors_and_spectra(posterior_data_loader, num_iterations=num_spectrum_iterations, + summary_writer=summary_writer, ignored_to_non_ignored_ratio=num_ignored_sites/len(posterior_data_loader.dataset), + learning_rate=spectrum_learning_rate) + + print("Calculating optimal logit threshold") + error_probability_thresholds = posterior_model.calculate_probability_thresholds(posterior_data_loader, summary_writer, germline_mode=germline_mode) + print(f"Optimal probability threshold: {error_probability_thresholds}") + apply_filtering_to_vcf(input_vcf, output_vcf, contig_index_to_name_map, error_probability_thresholds, posterior_data_loader, posterior_model, summary_writer=summary_writer, germline_mode=germline_mode) + + +@torch.inference_mode() +def make_posterior_data_loader(dataset_file, input_vcf, contig_index_to_name_map, base_model: BaseModel, artifact_model: ArtifactModel, + batch_size: int, num_workers: int, chunk_size: int, segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)): + print("Reading test dataset") + + m2_filtering_to_keep = set() + allele_frequencies = {} + + print("recording M2 filters and allele frequencies from input VCF") + pbar = tqdm(enumerate(cyvcf2.VCF(input_vcf)), mininterval=60) + for n, v in pbar: + encoding = encode_variant(v, zero_based=True) + if filters_to_keep_from_m2(v): + m2_filtering_to_keep.add(encoding) + allele_frequencies[encoding] = 10 ** (-get_first_numeric_element(v, "POPAF")) + + # pass through the plain text dataset, normalizing and creating ReadSetDatasets as we go, running the artifact model + # to get artifact logits, which we record in a dict keyed by variant strings. These will later be added to PosteriorDatum objects. + print("reading dataset and calculating artifact logits") + print(f"Memory usage percent before loading data: {psutil.virtual_memory().percent:.1f}") + posterior_data = [] + for list_of_base_data in plain_text_data.generate_normalized_data([dataset_file], chunk_size): + print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}") + raw_dataset = base_dataset.BaseDataset(data_in_ram=list_of_base_data) + print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + artifact_dataset = ArtifactDataset(raw_dataset, base_model) + print(f"Memory usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + artifact_loader = artifact_dataset.make_data_loader(artifact_dataset.all_folds(), batch_size, pin_memory=torch.cuda.is_available(), num_workers=num_workers) + + print("creating posterior data for this chunk...") + pbar = tqdm(enumerate(artifact_loader), mininterval=60) + for n, artifact_batch_cpu in pbar: + artifact_batch = artifact_batch_cpu.copy_to(device=artifact_model._device, dtype=artifact_model._dtype, non_blocking=artifact_model._device.type == 'cuda') + artifact_logits, _, _ = artifact_model.forward(batch=artifact_batch) + + labels = [(Label.ARTIFACT if label > 0.5 else Label.VARIANT) if is_labeled > 0.5 else Label.UNLABELED for (label, is_labeled) in zip(artifact_batch.get_training_labels(), artifact_batch.get_is_labeled_mask())] + + for variant,counts_and_seq_lks, logit, label, embedding in zip(artifact_batch_cpu.get_variants(), + artifact_batch_cpu.get_counts_and_seq_lks(), + artifact_logits.detach().tolist(), + labels, + artifact_batch.get_representations_2d().cpu()): + contig_name = contig_index_to_name_map[variant.contig] + encoding = encode(contig_name, variant.position, variant.ref, variant.alt) + if encoding in allele_frequencies and encoding not in m2_filtering_to_keep: + allele_frequency = allele_frequencies[encoding] + + # these are default dicts, so if there's no segmentation for the contig we will get no overlaps but not an error + # For a general IntervalTree there is a list of potentially multiple overlaps but here there is either one or zero + segmentation_overlaps = segmentation[contig_name][variant.position] + normal_segmentation_overlaps = normal_segmentation[contig_name][variant.position] + maf = list(segmentation_overlaps)[0].data if segmentation_overlaps else 0.5 + normal_maf = list(normal_segmentation_overlaps)[0].data if normal_segmentation_overlaps else 0.5 + + posterior_datum = PosteriorDatum(variant, counts_and_seq_lks, allele_frequency, logit, embedding, label, maf, normal_maf) + posterior_data.append(posterior_datum) + + print(f"Size of filtering dataset: {len(posterior_data)}") + posterior_dataset = PosteriorDataset(posterior_data) + print(f"Memory usage percent after creating PosteriorDataset: {psutil.virtual_memory().percent:.1f}") + return posterior_dataset.make_data_loader(batch_size, pin_memory=torch.cuda.is_available(), num_workers=num_workers) + + +# error probability thresholds is a dict from Variant type to error probability threshold (float) +@torch.inference_mode() +def apply_filtering_to_vcf(input_vcf, output_vcf, contig_index_to_name_map, error_probability_thresholds, + posterior_loader, posterior_model, summary_writer: SummaryWriter, germline_mode: bool = False): + print("Computing final error probabilities") + passing_call_type = Call.GERMLINE if germline_mode else Call.SOMATIC + encoding_to_posterior_results = {} + + pbar = tqdm(enumerate(posterior_loader), mininterval=60) + batch_cpu: PosteriorBatch + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(device=posterior_model._device, dtype=posterior_model._dtype, non_blocking=posterior_model._device.type == 'cuda') + # posterior, along with intermediate tensors for debugging/interpretation + log_priors, spectra_lls, normal_lls, log_posteriors = \ + posterior_model.log_posterior_and_ingredients(batch) + + posterior_probs = torch.nn.functional.softmax(log_posteriors, dim=1) + + encodings = [encode_datum(variant, contig_index_to_name_map) for variant in batch.get_variants()] + artifact_logits = batch.get_artifact_logits().tolist() + var_types = batch.get_variant_types().tolist() + labels = batch.get_labels().tolist() + alt_counts = batch.get_alt_counts().tolist() + depths = batch.get_depths().tolist() + + for encoding, post_probs, logit, log_prior, log_spec, log_normal, label, alt_count, depth, var_type, embedding in zip(encodings, posterior_probs, artifact_logits, log_priors, spectra_lls, normal_lls, labels, alt_counts, depths, var_types, batch.embeddings): + encoding_to_posterior_results[encoding] = PosteriorResult(logit, post_probs.tolist(), log_prior, log_spec, log_normal, label, alt_count, depth, var_type, embedding) + + print("Applying threshold") + unfiltered_vcf = cyvcf2.VCF(input_vcf) + + all_types = [call_type.name for call_type in Call] + unfiltered_vcf.add_format_to_header( {'ID': "DP", 'Description': "depth", 'Type': 'Integer', 'Number': '1'}) + unfiltered_vcf.add_info_to_header({'ID': POST_PROB_INFO_KEY, 'Description': 'Mutect3 posterior probability of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': LOG_PRIOR_INFO_KEY, 'Description': 'Log priors of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': SPECTRA_LOG_LIKELIHOOD_INFO_KEY, 'Description': 'Log spectra likelihoods of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': NORMAL_LOG_LIKELIHOOD_INFO_KEY, 'Description': 'Log normal likelihoods of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': ARTIFACT_LOD_INFO_KEY, 'Description': 'Mutect3 artifact log odds', + 'Type': 'Float', 'Number': 'A'}) + + for n, filter_name in enumerate(FILTER_NAMES): + if n != passing_call_type: + unfiltered_vcf.add_filter_to_header({'ID': filter_name, 'Description': filter_name}) + + writer = cyvcf2.Writer(output_vcf, unfiltered_vcf) # input vcf is a template for the header + evaluation_metrics = EvaluationMetrics() + pbar = tqdm(enumerate(unfiltered_vcf), mininterval=60) + labeled_truth = False + embedding_metrics = EmbeddingMetrics() # only if there is labeled truth for evaluation + + missing_encodings = [] + for n, v in pbar: + filters = filters_to_keep_from_m2(v) + + # TODO: in germline mode, somatic doesn't exist (or is just highly irrelevant) and germline is not an error! + encoding = encode_variant(v, zero_based=True) # cyvcf2 is zero-based + if encoding in encoding_to_posterior_results: + posterior_result = encoding_to_posterior_results[encoding] + post_probs = posterior_result.posterior_probabilities + v.INFO[POST_PROB_INFO_KEY] = ','.join(map(lambda prob: "{:.3f}".format(prob), post_probs)) + v.INFO[LOG_PRIOR_INFO_KEY] = ','.join(map(lambda pri: "{:.3f}".format(pri), posterior_result.log_priors)) + v.INFO[SPECTRA_LOG_LIKELIHOOD_INFO_KEY] = ','.join(map(lambda ll: "{:.3f}".format(ll), posterior_result.spectra_lls)) + v.INFO[ARTIFACT_LOD_INFO_KEY] = "{:.3f}".format(posterior_result.artifact_logit) + v.INFO[NORMAL_LOG_LIKELIHOOD_INFO_KEY] = ','.join(map(lambda ll: "{:.3f}".format(ll), posterior_result.normal_lls)) + + label = Label(posterior_result.label) # this is the Label enum, might be UNLABELED + error_prob = 1 - post_probs[passing_call_type] + variant_type = find_variant_type(v) + called_as_error = error_prob > error_probability_thresholds[variant_type] + + error_call = None + + if called_as_error: + # get the error type with the largest posterior probability + highest_prob_indices = torch.topk(torch.Tensor(post_probs), 2).indices.tolist() + highest_prob_index = highest_prob_indices[1] if highest_prob_indices[0] == passing_call_type else highest_prob_indices[0] + error_call = list(Call)[highest_prob_index] + filters.add(FILTER_NAMES[highest_prob_index]) + + # note that this excludes the correctness part of embedding metrics, which is below + embedding_metrics.label_metadata.append(label.name) + embedding_metrics.type_metadata.append(variant_type.name) + embedding_metrics.truncated_count_metadata.append( + str(round_up_to_nearest_three(min(MAX_COUNT, posterior_result.alt_count)))) + embedding_metrics.representations.append(posterior_result.embedding) + + correctness_label = "unknown" + if label != Label.UNLABELED: + labeled_truth = True + clipped_error_prob = 0.5 + 0.9999999 * (error_prob - 0.5) + error_logit = prob_to_logit(clipped_error_prob) + float_label = 1.0 if label == Label.ARTIFACT else 0.0 + + # TODO: this is sloppy -- it only works because when we label the posterior dataset (if truth is available) + # TODO: we stretch the definitions so that "Label.ARTIFACT" simply means "something we shouldn't call", including + # TODO: artifact or germline (in the somatic calling case), and "Label.VARIANT" means "something we should call" + is_correct = (called_as_error and label == Label.ARTIFACT) or (not called_as_error and label == Label.VARIANT) + evaluation_metrics.record_call(Epoch.TEST, variant_type, error_logit, float_label, is_correct, posterior_result.alt_count) + + # TODO: double-check the logic here + if is_correct: + if label == Label.VARIANT: + correctness_label = EmbeddingMetrics.TRUE_POSITIVE + elif error_call == Call.ARTIFACT or error_call == Call.NORMAL_ARTIFACT: + correctness_label = EmbeddingMetrics.TRUE_NEGATIVE_ARTIFACT + #elif error_call == Call.SEQ_ERROR: + # correctness_label = EmbeddingMetrics.TRUE_NEGATIVE_SEQ_ERROR + # we don't do anything for germline (in somatic mode) or seq error -- + else: + if called_as_error: + if error_call == Call.ARTIFACT or error_call == Call.NORMAL_ARTIFACT: + correctness_label = EmbeddingMetrics.FALSE_NEGATIVE_ARTIFACT + else: + correctness_label = EmbeddingMetrics.FALSE_POSITIVE + # TODO: this is only right for somatic calling + bad_call = error_call if called_as_error else Call.SOMATIC + evaluation_metrics.record_mistake(posterior_result, bad_call) + embedding_metrics.correct_metadata.append(correctness_label) + else: + missing_encodings.append(encoding) + v.FILTER = ';'.join(filters) if filters else 'PASS' + writer.write_record(v) + print("closing resources") + writer.close() + unfiltered_vcf.close() + + embedding_metrics.output_to_summary_writer(summary_writer, is_filter_variants=True) + + if labeled_truth: + given_thresholds = {var_type: prob_to_logit(error_probability_thresholds[var_type]) for var_type in Variation} + evaluation_metrics.make_plots(summary_writer, given_thresholds, sens_prec=True) + evaluation_metrics.make_mistake_histograms(summary_writer) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/preprocess_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/preprocess_dataset.py new file mode 100644 index 00000000000..1eda5e64392 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/preprocess_dataset.py @@ -0,0 +1,67 @@ +import argparse +import os +import tarfile +import tempfile +from typing import List + +from permutect import constants +from permutect.data import base_datum +from permutect.data.plain_text_data import generate_normalized_data +from permutect.utils import ConsistentValue + +""" +This tool takes as input a list of text file Mutect3 training datasets, reads them in chunks that fit in memory, +normalizes each chunk, outputs each chunk as a binary PyTorch file, and bundles the output as a tarfile. +""" + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='preprocess plain text training dataset into tarfile of nprmalized binary data') + parser.add_argument('--' + constants.TRAINING_DATASETS_NAME, nargs='+', type=str, required=True, + help='list of plain text data files') + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=int(2e9), required=False, + help='size in bytes of output binary data files') + parser.add_argument('--' + constants.SOURCES_NAME, nargs='+', type=int, required=False, + help='integer sources corresponding to plain text data files for distinguishing different sequencing conditions') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, default=None, required=True, + help='path to output tarfile of training data') + return parser.parse_args() + + +def do_work(training_datasets, training_output_file, chunk_size, sources: List[int]): + data_files = [] + num_read_features, num_info_features, ref_sequence_length = ConsistentValue(), ConsistentValue(), ConsistentValue() + + # save all the lists of read sets to tempfiles. . . + # TODO: left off here. Need to give it sources, which will need to be command line argument + for base_data_list in generate_normalized_data(training_datasets, max_bytes_per_chunk=chunk_size, sources=sources): + num_read_features.check(base_data_list[0].get_reads_2d().shape[1]) + num_info_features.check(base_data_list[0].get_info_tensor_1d().shape[0]) + ref_sequence_length.check(base_data_list[0].get_ref_sequence_1d().shape[0]) + + with tempfile.NamedTemporaryFile(delete=False) as train_data_file: + base_datum.save_list_base_data(base_data_list, train_data_file) + data_files.append(train_data_file.name) + + # . . . and bundle them in a tarfile + with tarfile.open(training_output_file, "w") as train_tar: + for train_file in data_files: + train_tar.add(train_file, arcname=os.path.basename(train_file)) + + +def main_without_parsing(args): + chunk_size = getattr(args, constants.CHUNK_SIZE_NAME) + training_datasets = getattr(args, constants.TRAINING_DATASETS_NAME) + output_file = getattr(args, constants.OUTPUT_NAME) + sources = getattr(args, constants.SOURCES_NAME) + + do_work(training_datasets, output_file, chunk_size, sources) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/prune_dataset.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/prune_dataset.py new file mode 100644 index 00000000000..fd6d1d292f9 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/prune_dataset.py @@ -0,0 +1,249 @@ +import argparse +import os +import tarfile +import tempfile +from typing import List + +import psutil + +from permutect.architecture.base_model import load_base_model, BaseModel +from permutect.data import base_datum +from tqdm.autonotebook import tqdm + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from permutect import constants, utils +from permutect.architecture.artifact_model import ArtifactModel +from permutect.data.base_datum import ArtifactDatum, ArtifactBatch +from permutect.data.artifact_dataset import ArtifactDataset +from permutect.parameters import ArtifactModelParameters, parse_artifact_model_params, \ + add_artifact_model_params_to_parser, add_training_params_to_parser +from permutect.data.base_dataset import BaseDataset +from permutect.tools.train_model import TrainingParameters, parse_training_params +from permutect.utils import Label + +NUM_FOLDS = 3 + + +# labeled only pruning loader must be constructed with options to emit batches of all-labeled data +def calculate_pruning_thresholds(labeled_only_pruning_loader, artifact_model: ArtifactModel, label_art_frac: float, training_params: TrainingParameters) -> List[int]: + for fold in range(NUM_FOLDS): + average_artifact_confidence, average_nonartifact_confidence = utils.StreamingAverage(), utils.StreamingAverage() + # TODO: eventually this should all be segregated by variant type and maybe also alt count + + # the 0th/1st element is a list of predicted probabilities that data labeled as non-artifact/artifact are actually non-artifact/artifact + probs_of_agreeing_with_label = [[],[]] + print("calculating average confidence and gathering predicted probabilities") + pbar = tqdm(enumerate(labeled_only_pruning_loader), mininterval=60) + for n, batch in pbar: + # TODO: should we use likelihoods as in evaluation or posteriors as in training??? + # TODO: does it even matter?? + art_logits, _, _ = artifact_model.forward(batch) + art_probs = torch.sigmoid(art_logits.detach()) + + labels = batch.get_training_labels() + art_label_mask = (labels > 0.5) + nonart_label_mask = (labels < 0.5) + average_artifact_confidence.record_with_mask(art_probs, art_label_mask) + average_nonartifact_confidence.record_with_mask(1 - art_probs, nonart_label_mask) + + for art_prob, labeled_as_art in zip(art_probs.tolist(), art_label_mask.tolist()): + agreement_prob = art_prob if labeled_as_art else (1 - art_prob) + probs_of_agreeing_with_label[1 if labeled_as_art else 0].append(agreement_prob) + + # TODO: it is wasteful to run forward passes on all the data again when we can just record indices and logits + print("estimating error rates") + # The i,j element is the count of data labeled as i that pass the confidence threshold for j + # here 0 means non-artifact and 1 means artifact + confusion = [[0, 0], [0, 0]] + art_conf_threshold = average_artifact_confidence.get() + nonart_conf_threshold = average_nonartifact_confidence.get() + pbar = tqdm(enumerate(labeled_only_pruning_loader), mininterval=60) + for n, batch in pbar: + predicted_artifact_logits, _, _ = artifact_model.forward(batch) + predicted_artifact_probs = torch.sigmoid(predicted_artifact_logits.detach()) + + conf_art_mask = predicted_artifact_probs >= art_conf_threshold + conf_nonart_mask = (1 - predicted_artifact_probs) >= nonart_conf_threshold + art_label_mask = (batch.get_training_labels() > 0.5) + + for conf_artifact, conf_nonartifact, artifact_label in zip(conf_art_mask.tolist(), conf_nonart_mask.tolist(), art_label_mask.tolist()): + row = 1 if artifact_label else 0 + if conf_artifact: + confusion[row][1] += 1 + if conf_nonartifact: + confusion[row][0] += 1 + + # these are the probabilities of a true (hidden label) artifact/non-artifact being mislabeled as non-artifact/artifact + art_error_rate = confusion[0][1] / (confusion[0][1] + confusion[1][1]) + nonart_error_rate = confusion[1][0] / (confusion[0][0] + confusion[1][0]) + + # fraction of labeled data that are labeled as artifact + label_nonart_frac = 1 - label_art_frac + + # these are the inverse probabilities that something labeled as artifact/non-artifact was actually a mislabeled nonartifact/artifact + inv_art_error_rate = (nonart_error_rate / label_art_frac) * (label_nonart_frac - art_error_rate) / (1 - art_error_rate - nonart_error_rate) + inv_nonart_error_rate = (art_error_rate / label_nonart_frac) * (label_art_frac - nonart_error_rate) / (1 - art_error_rate - nonart_error_rate) + + print("Estimated error rates: ") + print(f"artifact mislabeled as non-artifact: {art_error_rate:.3f}") + print(f"non-artifact mislabeled as artifact: {nonart_error_rate:.3f}") + + print("Estimated inverse error rates: ") + print(f"Labeled artifact was actually non-artifact: {inv_art_error_rate:.3f}") + print(f"Labeled non-artifact was actually artifact: {inv_nonart_error_rate:.3f}") + + print("calculating rank pruning thresholds") + nonart_threshold = torch.quantile(torch.Tensor(probs_of_agreeing_with_label[0]), inv_nonart_error_rate).item() + art_threshold = torch.quantile(torch.Tensor(probs_of_agreeing_with_label[1]), inv_art_error_rate).item() + + print("Rank pruning thresholds: ") + print(f"Labeled artifacts are pruned if predicted artifact probability is less than {art_threshold:.3f}") + print(f"Labeled non-artifacts are pruned if predicted non-artifact probability is less than {nonart_threshold:.3f}") + + return art_threshold, nonart_threshold + + +# generates BaseDatum(s) from the original dataset that *pass* the pruning thresholds +def generated_pruned_data_for_fold(art_threshold: float, nonart_threshold: float, pruning_base_data_loader, + base_model: BaseModel, artifact_model: ArtifactModel) -> List[int]: + print("pruning the dataset") + pbar = tqdm(enumerate(pruning_base_data_loader), mininterval=60) + for n, base_batch in pbar: + # apply the representation model AND the artifact model to go from the original read set to artifact logits + representation, _ = base_model.calculate_representations(base_batch) + + artifact_batch = ArtifactBatch([ArtifactDatum(rs, rep) for rs, rep in zip(base_batch.original_list(), representation.detach())]) + + art_logits, _, _ = artifact_model.forward(artifact_batch) + art_probs = torch.sigmoid(art_logits.detach()) + art_label_mask = (base_batch.get_training_labels() > 0.5) + is_labeled_mask = (base_batch.get_is_labeled_mask() > 0.5) + + for art_prob, labeled_as_art, datum, is_labeled in zip(art_probs.tolist(), art_label_mask.tolist(), base_batch.original_list(), is_labeled_mask.tolist()): + if not is_labeled: + yield datum + elif (labeled_as_art and art_prob < art_threshold) or ((not labeled_as_art) and (1-art_prob) < nonart_threshold): + # TODO: process failing data, perhaps add option to output a pruned dataset? or flip labels? + pass + else: + yield datum # this is a ReadSet + + +def generate_pruned_data_for_all_folds(base_dataset: BaseDataset, base_model: BaseModel, + training_params: TrainingParameters, params: ArtifactModelParameters, tensorboard_dir): + # for each fold in turn, train an artifact model on all other folds and prune the chosen fold + use_gpu = torch.cuda.is_available() + device = torch.device('cuda' if use_gpu else 'cpu') + + for pruning_fold in range(NUM_FOLDS): + summary_writer = SummaryWriter(tensorboard_dir + "/fold_" + str(pruning_fold)) + print(f"Pruning data from fold {pruning_fold} of {NUM_FOLDS}") + print(f"Memory usage percent: {psutil.virtual_memory().percent:.3f}") + + # learn an artifact model with the pruning data held out + artifact_dataset = ArtifactDataset(base_dataset, base_model, base_dataset.all_but_one_fold(pruning_fold)) + + # sum is over variant types + label_art_frac = np.sum(artifact_dataset.totals[-1][Label.ARTIFACT]) / np.sum(artifact_dataset.totals[-1][Label.ARTIFACT] + + artifact_dataset.totals[-1][Label.VARIANT]) + + # learn pruning thresholds on the held-out data + pruning_artifact_dataset = ArtifactDataset(base_dataset, base_model, [pruning_fold]) + labeled_only_pruning_loader = pruning_artifact_dataset.make_data_loader(pruning_artifact_dataset.all_folds(), + training_params.batch_size, use_gpu, training_params.num_workers, labeled_only=True) + model = ArtifactModel(params=params, num_base_features=artifact_dataset.num_base_features, num_ref_alt_features=base_model.ref_alt_seq_embedding_dimension(), device=device).float() + model.learn(artifact_dataset, training_params, summary_writer=summary_writer) + + # TODO: maybe this should be done by variant type and/or count + art_threshold, nonart_threshold = calculate_pruning_thresholds(labeled_only_pruning_loader, model, label_art_frac, training_params) + + # unlike when learning thresholds, we load labeled and unlabeled data here + pruning_base_data_loader = base_dataset.make_data_loader([pruning_fold], training_params.batch_size, use_gpu, training_params.num_epochs) + for passing_base_datum in generated_pruned_data_for_fold(art_threshold, nonart_threshold, pruning_base_data_loader, base_model, model): + yield passing_base_datum + + +# takes a ReadSet generator and organies into buffers. +# TODO: probably code duplication since the generator is already pruned +def generate_pruned_data_buffers(pruned_data_generator, max_bytes_per_chunk: int): + buffer, bytes_in_buffer = [], 0 + for datum in pruned_data_generator: + + buffer.append(datum) + bytes_in_buffer += datum.size_in_bytes() + if bytes_in_buffer > max_bytes_per_chunk: + print(f"Memory usage percent: {psutil.virtual_memory().percent:.1f}") + print(f"{bytes_in_buffer} bytes in chunk") + yield buffer + buffer, bytes_in_buffer = [], 0 + + # There will be some data left over, in general. + if buffer: + yield buffer + + +def make_pruned_training_dataset(pruned_data_buffer_generator, pruned_tarfile): + pruned_data_files = [] + for base_data_list in pruned_data_buffer_generator: + with tempfile.NamedTemporaryFile(delete=False) as train_data_file: + base_datum.save_list_base_data(base_data_list, train_data_file) + pruned_data_files.append(train_data_file.name) + + # bundle them in a tarfile + with tarfile.open(pruned_tarfile, "w") as train_tar: + for train_file in pruned_data_files: + train_tar.add(train_file, arcname=os.path.basename(train_file)) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Mutect3 artifact model') + + add_artifact_model_params_to_parser(parser) + add_training_params_to_parser(parser) + + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=int(2e9), required=False, + help='size in bytes of output binary data files') + + # input / output + parser.add_argument('--' + constants.TRAIN_TAR_NAME, type=str, required=True, + help='tarfile of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='path to pruned dataset file') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, + help='path to output tensorboard directory') + + return parser.parse_args() + + +def main_without_parsing(args): + params = parse_artifact_model_params(args) + training_params = parse_training_params(args) + + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + pruned_tarfile = getattr(args, constants.OUTPUT_NAME) + chunk_size = getattr(args, constants.CHUNK_SIZE_NAME) + original_tarfile = getattr(args, constants.TRAIN_TAR_NAME) + + base_model = load_base_model(getattr(args, constants.BASE_MODEL_NAME)) + base_dataset = BaseDataset(data_tarfile=original_tarfile, num_folds=NUM_FOLDS) + + # generate ReadSets passing pruning + pruned_data_generator = generate_pruned_data_for_all_folds(base_dataset, base_model, training_params, params, tensorboard_dir) + + # generate List[ReadSet]s passing pruning + pruned_data_buffer_generator = generate_pruned_data_buffers(pruned_data_generator, chunk_size) + + # save as a tarfile dataset + make_pruned_training_dataset(pruned_data_buffer_generator, pruned_tarfile=pruned_tarfile) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/train_base_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/train_base_model.py new file mode 100644 index 00000000000..4ba3e2bcb40 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/train_base_model.py @@ -0,0 +1,60 @@ +import argparse + +from torch.utils.tensorboard import SummaryWriter + +from permutect import constants, utils +from permutect.architecture.base_model import BaseModel, LearningMethod, load_base_model, learn_base_model +from permutect.parameters import BaseModelParameters, TrainingParameters, parse_training_params, \ + parse_base_model_params, add_base_model_params_to_parser, add_training_params_to_parser +from permutect.data.base_dataset import BaseDataset + + +def train_base_model(params: BaseModelParameters, training_params: TrainingParameters, learning_method: LearningMethod, + summary_writer: SummaryWriter, dataset: BaseDataset, pretrained_model: BaseModel = None) -> BaseModel: + base_model = pretrained_model if (pretrained_model is not None) else \ + BaseModel(params=params, num_read_features=dataset.num_read_features, num_info_features=dataset.num_info_features, + ref_sequence_length=dataset.ref_sequence_length, device=utils.gpu_if_available()) + learn_base_model(base_model, dataset, learning_method, training_params, summary_writer=summary_writer) + return base_model + + +def main_without_parsing(args): + params = parse_base_model_params(args) + training_params = parse_training_params(args) + + learning_method = LearningMethod[getattr(args, constants.LEARNING_METHOD_NAME)] + + tarfile_data = getattr(args, constants.TRAIN_TAR_NAME) + pretrained_model_path = getattr(args, constants.PRETRAINED_MODEL_NAME) + pretrained_model = None if pretrained_model_path is None else load_base_model(pretrained_model_path) + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + summary_writer = SummaryWriter(tensorboard_dir) + dataset = BaseDataset(data_tarfile=tarfile_data, num_folds=10) + + model = train_base_model(params=params, dataset=dataset, training_params=training_params, learning_method=learning_method, + summary_writer=summary_writer, pretrained_model=pretrained_model) + + summary_writer.close() + model.save(getattr(args, constants.OUTPUT_NAME)) + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Permutect read set representation model') + add_base_model_params_to_parser(parser) + add_training_params_to_parser(parser) + + parser.add_argument('--' + constants.LEARNING_METHOD_NAME, type=str, required=False, default='SEMISUPERVISED') + parser.add_argument('--' + constants.TRAIN_TAR_NAME, type=str, required=True, + help='tarfile of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='output saved model file') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, + help='output tensorboard directory') + + return parser.parse_args() + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/tools/train_model.py b/src/main/python/org/broadinstitute/hellbender/permutect/tools/train_model.py new file mode 100644 index 00000000000..2eaa1ebc0bc --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/tools/train_model.py @@ -0,0 +1,125 @@ +import argparse + +import psutil +import torch +from torch.utils.tensorboard import SummaryWriter + +from permutect import constants, utils +from permutect.architecture.artifact_model import ArtifactModel +from permutect.architecture.artifact_spectra import ArtifactSpectra +from permutect.architecture.posterior_model import plot_artifact_spectra +from permutect.architecture.base_model import load_base_model +from permutect.data.base_dataset import BaseDataset +from permutect.data.artifact_dataset import ArtifactDataset +from permutect.data.base_datum import ArtifactDatum +from permutect.parameters import TrainingParameters, add_training_params_to_parser, parse_training_params, \ + ArtifactModelParameters, parse_artifact_model_params, add_artifact_model_params_to_parser +from permutect.utils import Variation, Label + + +def train_artifact_model(hyperparams: ArtifactModelParameters, training_params: TrainingParameters, summary_writer: SummaryWriter, dataset: ArtifactDataset): + model = ArtifactModel(params=hyperparams, num_base_features=dataset.num_base_features, num_ref_alt_features=dataset.num_ref_alt_features, device=utils.gpu_if_available()) + # TODO: magic constant + model.learn(dataset, training_params, summary_writer=summary_writer, epochs_per_evaluation=10) + + for n, var_type in enumerate(Variation): + cal_fig, cal_axes = model.calibration[n].plot_calibration() + summary_writer.add_figure("calibration by count for " + var_type.name, cal_fig) + + return model + + +def learn_artifact_priors_and_spectra(artifact_dataset: ArtifactDataset, genomic_span_of_data: int): + artifact_counts = torch.zeros(len(utils.Variation)) + types_list, depths_list, alt_counts_list = [], [], [] + + artifact_datum: ArtifactDatum + for artifact_datum in artifact_dataset: + if artifact_datum.get_label() != Label.ARTIFACT: + continue + variant_type = artifact_datum.get_variant_type() + artifact_counts[variant_type] += 1 + types_list.append(variant_type) + counts_and_seq_lks = artifact_datum.one_dimensional_data.get_counts_and_seq_lks() + depths_list.append(counts_and_seq_lks.depth) + alt_counts_list.append(counts_and_seq_lks.alt_count) + + # turn the lists into tensors + types_tensor = torch.LongTensor(types_list) + depths_tensor = torch.Tensor(depths_list).float() + alt_counts_tensor = torch.Tensor(alt_counts_list).float() + + log_artifact_priors = torch.log(artifact_counts / genomic_span_of_data) + artifact_spectra = ArtifactSpectra(num_components=2) + + # TODO: hard-coded num epochs!!! + artifact_spectra.fit(num_epochs=10, types_b=types_tensor, depths_1d_tensor=depths_tensor, + alt_counts_1d_tensor=alt_counts_tensor, batch_size=64) + + return log_artifact_priors, artifact_spectra + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Permutect artifact model') + + add_artifact_model_params_to_parser(parser) + add_training_params_to_parser(parser) + + parser.add_argument('--' + constants.LEARN_ARTIFACT_SPECTRA_NAME, action='store_true', + help='flag to include artifact priors and allele fraction spectra in saved output. ' + 'This is worth doing if labeled training data is available but might work poorly ' + 'when Mutect3 generates weak labels based on allele fractions.') + parser.add_argument('--' + constants.GENOMIC_SPAN_NAME, type=float, required=False, + help='Total number of sites considered by Mutect2 in all training data, including those lacking variation or artifacts, hence absent from input datasets. ' + 'Necessary for learning priors since otherwise rates of artifacts and variants would be overinflated. ' + 'Only required if learning artifact log priors') + + # inputs and outputs + parser.add_argument('--' + constants.TRAIN_TAR_NAME, type=str, required=True, + help='tarfile of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='path to output saved model file') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, + help='path to output tensorboard directory') + + return parser.parse_args() + + +def main_without_parsing(args): + params = parse_artifact_model_params(args) + training_params = parse_training_params(args) + learn_artifact_spectra = getattr(args, constants.LEARN_ARTIFACT_SPECTRA_NAME) + genomic_span = getattr(args, constants.GENOMIC_SPAN_NAME) + + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + summary_writer = SummaryWriter(tensorboard_dir) + + base_model = load_base_model(getattr(args, constants.BASE_MODEL_NAME)) + print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}") + base_dataset = BaseDataset(data_tarfile=getattr(args, constants.TRAIN_TAR_NAME), num_folds=10) + print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + artifact_dataset = ArtifactDataset(base_dataset, + base_model, + base_loader_num_workers=training_params.num_workers, + base_loader_batch_size=training_params.inference_batch_size) + print(f"Memory usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + + model = train_artifact_model(hyperparams=params, training_params=training_params, summary_writer=summary_writer, dataset=artifact_dataset) + print(f"Memory usage percent after training artifact model: {psutil.virtual_memory().percent:.1f}") + + artifact_log_priors, artifact_spectra = learn_artifact_priors_and_spectra(artifact_dataset, genomic_span) if learn_artifact_spectra else (None, None) + if artifact_spectra is not None: + art_spectra_fig, art_spectra_axs = plot_artifact_spectra(artifact_spectra, depth=50) + summary_writer.add_figure("Artifact AF Spectra", art_spectra_fig) + + summary_writer.close() + model.save_with_base_model(base_model, getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/python/org/broadinstitute/hellbender/permutect/utils.py b/src/main/python/org/broadinstitute/hellbender/permutect/utils.py new file mode 100644 index 00000000000..7d25fce8393 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/permutect/utils.py @@ -0,0 +1,292 @@ +import enum +import numpy as np +import cyvcf2 +import tarfile +import os +import torch + + +class ConsistentValue: + """ + Tracks a value that once initialized, is consistent among eg all members of a dataset. For example, all tensors + must have the same number of columns. + """ + def __init__(self, value=None): + self.value = value + + def check(self, value): + if self.value is None: + self.value = value + else: + assert self.value == value, "inconsistent values" + + +class MutableInt: + def __init__(self, value:int = 0): + self.value = value + + def __str__(self): + return str(self.value) + + def increment(self, amount: int = 1): + self.value += amount + + def decrement(self, amount: int = 1): + self.value -= amount + + def get_and_then_increment(self): + self.value += 1 + return self.value - 1 + + def get(self): + return self.value + + def set(self, value: int): + self.value = value + + +def gpu_if_available(exploit_mps=False) -> torch.device: + if torch.cuda.is_available(): + d = 'cuda' + elif exploit_mps and torch.mps.is_available(): + d = 'mps' + else: + d = 'cpu' + return torch.device(d) + + +def downsample_tensor(tensor2d: np.ndarray, new_length: int): + if tensor2d is None or new_length >= len(tensor2d): + return tensor2d + perm = np.random.permutation(len(tensor2d)) + return tensor2d[perm[:new_length]] + + +def get_variant_type(alt_allele, ref_allele): + variant_size = len(alt_allele) - len(ref_allele) + if variant_size == 0: + return Variation.SNV + else: + return Variation.INSERTION if variant_size > 0 else Variation.DELETION + + +class Variation(enum.IntEnum): + SNV = 0 + INSERTION = 1 + DELETION = 2 + BIG_INSERTION = 3 + BIG_DELETION = 4 + + @staticmethod + def get_type(ref_allele: str, alt_allele: str): + diff = len(alt_allele) - len(ref_allele) + if diff == 0: + return Variation.SNV + elif diff > 0: + return Variation.BIG_INSERTION if diff > 1 else Variation.INSERTION + else: + return Variation.BIG_DELETION if diff < -1 else Variation.DELETION + + +class Call(enum.IntEnum): + SOMATIC = 0 + ARTIFACT = 1 + SEQ_ERROR = 2 + GERMLINE = 3 + NORMAL_ARTIFACT = 4 + + +class Epoch(enum.IntEnum): + TRAIN = 0 + VALID = 1 + TEST = 2 + + +class Label(enum.IntEnum): + ARTIFACT = 0 + VARIANT = 1 + UNLABELED = 2 + + @staticmethod + def get_label(label_str: str): + for label in Label: + if label_str == label.name: + return label + + raise ValueError('label is invalid: %s' % label_str) + + @staticmethod + def is_label(label_str: str): + for label in Label: + if label_str == label.name: + return True + + return False + + +def freeze(parameters): + for parameter in parameters: + parameter.requires_grad = False + + +def unfreeze(parameters): + for parameter in parameters: + if parameter.dtype.is_floating_point: # an integer parameter isn't trainable by gradient descent + parameter.requires_grad = True + + +def f_score(tp, fp, total_true): + fn = total_true - tp + return tp / (tp + (fp + fn) / 2) + + +# note: this function works for n, k, alpha, beta tensors of the same shape +# the result is computed element-wise ie result[i,j. . .] = beta_binomial(n[i,j..], k[i,j..], alpha[i,j..], beta[i,j..) +# often n, k will correspond to a batch dimension and alpha, beta correspond to a model, in which case +# unsqueezing is necessary +# NOTE: this includes the nCk factor +def beta_binomial(n, k, alpha, beta): + combinatorial_term = torch.lgamma(n + 1) - torch.lgamma(n - k + 1) - torch.lgamma(k + 1) + return combinatorial_term + torch.lgamma(k + alpha) + torch.lgamma(n - k + beta) + torch.lgamma(alpha + beta) \ + - torch.lgamma(n + alpha + beta) - torch.lgamma(alpha) - torch.lgamma(beta) + + +# note: this function works for n, k, p tensors of the same shape +# the result is computed element-wise ie result[i,j. . .] = binomial(n[i,j..], k[i,j..], p[i,j..]) +# often n, k will correspond to a batch dimension and p correspond to a model, in which case +# unsqueezing is necessary +# NOTE: this includes the nCk factor +def binomial(n, k, p): + combinatorial_term = torch.lgamma(n + 1) - torch.lgamma(n - k + 1) - torch.lgamma(k + 1) + return combinatorial_term + k * torch.log(p) + (n - k) * torch.log(1 - p) + + +# note: this function works for n, k, alpha, beta tensors of the same shape +# the result is computed element-wise ie result[i,j. . .] = gamma_binomial(n[i,j..], k[i,j..], alpha[i,j..], beta[i,j..) +# often n, k will correspond to a batch dimension and alpha, beta correspond to a model, in which case +# unsqueezing is necessary +# NOTE: this includes the nCk factor +# WARNING: the approximations here only work if Gamma(f|alpha, beta) has very little density for f > 1 +# see pp 2 - 4 of my notebook +def gamma_binomial(n, k, alpha, beta): + alpha_tilde = (k + 1) * (n + 2) / (n - k + 1) + beta_tilde = (n + 1) * (n + 2) / (n - k + 1) + + exponent_term = alpha_tilde * torch.log(beta_tilde) + alpha * torch.log(beta) -\ + (alpha + alpha_tilde - 1) * torch.log(beta + beta_tilde) + gamma_term = torch.lgamma(alpha + alpha_tilde - 1) - torch.lgamma(alpha) - torch.lgamma(alpha_tilde) + return exponent_term + gamma_term - torch.log(n + 1) + + +# for tensor of shape (R, C...) and row counts n1, n2. . nK, return a tensor of shape (K, C...) whose 1st row is the sum of the +# first n1 rows of the input, 2nd row is the sum of the next n2 rows etc +# note that this works for arbitrary C, including empty. That is, it works for 1D, 2D, 3D etc input. +def sums_over_rows(input_tensor: torch.Tensor, counts: torch.IntTensor): + range_ends = torch.cumsum(counts, dim=0) + assert range_ends[-1] == len(input_tensor) # the counts need to add up! + + row_cumsums = torch.cumsum(input_tensor, dim=0) + + # if counts are eg 1, 2, 3 then range ends are 1, 3, 6 and we are interested in cumsums[0, 2, 5] + relevant_cumsums = row_cumsums[(range_ends - 1).long()] + + # if counts are eg 1, 2, 3 we now have, the sum of the first 1, 3, and 6 rows. To get the sums of row 0, rows 1-2, rows 3-5 + # we need the consecutive differences, with a row of zeroes prepended + row_of_zeroes = torch.zeros_like(relevant_cumsums[0])[None] # the [None] makes it (1xC) + relevant_sums = torch.diff(relevant_cumsums, dim=0, prepend=row_of_zeroes) + return relevant_sums + + +# same but divide by the counts to get means +def means_over_rows(input_tensor: torch.Tensor, counts: torch.IntTensor, keepdim: bool = False): + extra_dims = (1,) * (input_tensor.dim() - 1) + result = sums_over_rows(input_tensor, counts) / counts.view(-1, *extra_dims) + + return torch.repeat_interleave(result, dim=0, repeats=counts) if keepdim else result + + +# given 3d tensor T_ijk and 1D index tensors I, J, K, return the 1D tensor: +# result[n] = T[I[n], J[n], K[n]] +def index_3d_array(tens, idx0, idx1, idx2): + dim0, dim1, dim2 = tens.shape + flattened_indices = (idx0 * dim1 * dim2) + (idx1 * dim2) + idx2 + return tens.view(-1)[flattened_indices] + + +def index_2d_array(tens, idx0, idx1): + dim0, dim1 = tens.shape + flattened_indices = (idx0 * dim1) + idx1 + return tens.view(-1)[flattened_indices] + + +# same but include a regularizer in case of zeros in the counts vector +# regularizer has the dimension of one row of the input tensor +def means_over_rows_with_regularizer(input_tensor: torch.Tensor, counts: torch.IntTensor, regularizer, regularizer_weight, keepdim: bool = False): + extra_dims = (1,) * (input_tensor.dim() - 1) + + regularized_sums = sums_over_rows(input_tensor, counts) + (regularizer_weight * regularizer)[None, :] + regularized_counts = counts + regularizer_weight + result = regularized_sums / regularized_counts.view(-1, *extra_dims) + + return torch.repeat_interleave(result, dim=0, repeats=counts) if keepdim else result + + +class StreamingAverage: + def __init__(self): + self._count = 0.0 + self._sum = 0.0 + + def is_empty(self) -> bool: + return self._count == 0.0 + + def get(self) -> float: + return self._sum / (self._count + 0.0001) + + def record(self, value: float, weight: float=1): + self._count += weight + self._sum += value * weight + + def record_sum(self, value_sum: float, count): + self._count += count + self._sum += value_sum + + # record only values masked as true + def record_with_mask(self, values: torch.Tensor, mask: torch.Tensor): + self._count += torch.sum(mask).item() + self._sum += torch.sum(values*mask).item() + + # record values with different weights + # values and mask should live on same device as self._sum + def record_with_weights(self, values: torch.Tensor, weights: torch.Tensor): + self._count += torch.sum(weights).item() + self._sum += torch.sum(values * weights).item() + + +def log_binomial_coefficient(n: torch.Tensor, k: torch.Tensor): + return (n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma() + + +def backpropagate(optimizer: torch.optim.Optimizer, loss: torch.Tensor): + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + + +def find_variant_type(v: cyvcf2.Variant): + alt = v.ALT[0] # TODO: we're assuming biallelic + ref = v.REF + return Variation.get_type(ref, alt) + + +def extract_to_temp_dir(tar_file, directory): + tar = tarfile.open(tar_file) + tar.extractall(directory) + tar.close() + return [os.path.abspath(os.path.join(directory, p)) for p in os.listdir(directory)] + + +def trim_alleles_on_right(ref: str, alt: str): + trimmed_ref, trimmed_alt = ref, alt + while len(trimmed_ref) > 1 and len(trimmed_alt) > 1 and trimmed_alt[-1] == trimmed_ref[-1]: + trimmed_ref, trimmed_alt = trimmed_ref[:-1], trimmed_alt[:-1] + return trimmed_ref, trimmed_alt \ No newline at end of file diff --git a/src/main/python/org/broadinstitute/hellbender/setup_permutext.py b/src/main/python/org/broadinstitute/hellbender/setup_permutext.py new file mode 100644 index 00000000000..4d277a2c364 --- /dev/null +++ b/src/main/python/org/broadinstitute/hellbender/setup_permutext.py @@ -0,0 +1,22 @@ +from setuptools import setup, find_packages + +setup( + name="Mutect 3", + version="0.1", + author="David Benjamin", + author_email="davidben@broadinstitute.org", + description="A new way to filter somatic variant calls", + license="Apache license version 2.0", + packages=find_packages(), + entry_points={ + 'console_scripts': ['train_model=permutect.tools.train_model:main', + 'train_base_model=permutect.tools.train_base_model:main', + 'filter_variants=permutect.tools.filter_variants:main', + 'preprocess_dataset=permutect.tools.preprocess_dataset:main', + 'edit_dataset=permutect.tools.edit_dataset:main', + 'prune_dataset=permutect.tools.prune_dataset:main', + 'evaluate_model=permutect.tools.evaluate_model:main', + 'compare_to_mutect2=permutect.tools.compare_to_mutect2:main' + ] + } +) diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/edit_dataset.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/edit_dataset.py new file mode 100644 index 00000000000..5e901de8df0 --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/edit_dataset.py @@ -0,0 +1,129 @@ +import argparse +import os +import tarfile +import tempfile +from enum import Enum + +import psutil +import torch.utils.data + +from permutect.data import base_datum +from tqdm.autonotebook import tqdm + +from permutect import constants +from permutect.data.base_dataset import BaseDataset +from permutect.utils import Label + + +class EditType(Enum): + UNLABEL_ARTIFACTS = "unlabel_artifacts" + UNLABEL_VARIANTS = "unlabel_variants" + UNLABEL_EVERYTHING = "unlabel_everything" + REMOVE_ARTIFACTS = "remove_artifacts" + REMOVE_VARIANTS = "remove_variants" + KEEP_EVERYTHING = "keep_everything" + + +# generates BaseDatum(s) from the original dataset that *pass* the pruning thresholds +def generate_edited_data(base_datasets, edit_type: str, source: int): + pbar = tqdm(enumerate(torch.utils.data.ConcatDataset(base_datasets)), mininterval=60) + + for n, base_datum in pbar: + if source is not None: + base_datum.set_source(source) + + if edit_type == EditType.UNLABEL_ARTIFACTS.value: + if base_datum.label == Label.ARTIFACT: + base_datum.set_label(Label.UNLABELED) + yield base_datum + elif edit_type == EditType.UNLABEL_VARIANTS.value: + if base_datum.label == Label.VARIANT: + base_datum.set_label(Label.UNLABELED) + yield base_datum + elif edit_type == EditType.UNLABEL_EVERYTHING.value: + base_datum.set_label(Label.UNLABELED) + yield base_datum + elif edit_type == EditType.REMOVE_ARTIFACTS.value: + if base_datum.label != Label.ARTIFACT: + yield base_datum + elif edit_type == EditType.REMOVE_VARIANTS.value: + if base_datum.label != Label.VARIANT: + yield base_datum + elif edit_type == EditType.KEEP_EVERYTHING.value: + yield base_datum + else: + raise Exception(f"edit type {edit_type} not implemented yet") + + +# takes a ReadSet generator and organizes into buffers. +# TODO: probably code duplication since the generator is already pruned +def generate_output_data_buffers(output_data_generator, max_bytes_per_chunk: int): + buffer, bytes_in_buffer = [], 0 + for datum in output_data_generator: + + buffer.append(datum) + bytes_in_buffer += datum.size_in_bytes() + if bytes_in_buffer > max_bytes_per_chunk: + print(f"Memory usage percent: {psutil.virtual_memory().percent:.1f}") + print(f"{bytes_in_buffer} bytes in chunk") + yield buffer + buffer, bytes_in_buffer = [], 0 + + # There will be some data left over, in general. + if buffer: + yield buffer + + +def make_output_training_dataset(pruned_data_buffer_generator, output_tarfile): + pruned_data_files = [] + for base_data_list in pruned_data_buffer_generator: + with tempfile.NamedTemporaryFile(delete=False) as train_data_file: + base_datum.save_list_base_data(base_data_list, train_data_file) + pruned_data_files.append(train_data_file.name) + + # bundle them in a tarfile + with tarfile.open(output_tarfile, "w") as train_tar: + for train_file in pruned_data_files: + train_tar.add(train_file, arcname=os.path.basename(train_file)) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Mutect3 artifact model') + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=int(2e9), required=False, + help='size in bytes of output binary data files') + parser.add_argument('--' + constants.DATASET_EDIT_TYPE_NAME, type=str, required=True, + help='how to modify the dataset') + parser.add_argument('--' + constants.SOURCE_NAME, type=int, required=False, help='new source integer to apply') + + # input / output + parser.add_argument('--' + constants.TRAIN_TAR_NAME, nargs='+', type=str, required=True, + help='tarfile(s) of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='path to pruned dataset file') + + return parser.parse_args() + + +def main_without_parsing(args): + original_tarfiles = getattr(args, constants.TRAIN_TAR_NAME) # list of files + output_tarfile = getattr(args, constants.OUTPUT_NAME) + chunk_size = getattr(args, constants.CHUNK_SIZE_NAME) + edit_type = getattr(args, constants.DATASET_EDIT_TYPE_NAME) + new_source = getattr(args, constants.SOURCE_NAME) + base_datasets = map(lambda original_tarfile: BaseDataset(data_tarfile=original_tarfile), original_tarfiles) + + # generate ReadSets + output_data_generator = generate_edited_data(base_datasets, edit_type, new_source) + + # generate List[ReadSet]s + output_data_buffer_generator = generate_output_data_buffers(output_data_generator, chunk_size) + + make_output_training_dataset(output_data_buffer_generator, output_tarfile=output_tarfile) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/evaluate_model.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/evaluate_model.py new file mode 100644 index 00000000000..2631418d385 --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/evaluate_model.py @@ -0,0 +1,42 @@ +import argparse +from torch.utils.tensorboard import SummaryWriter +from permutect import constants, utils +from permutect.data.base_dataset import BaseDataset +from permutect.architecture.artifact_model import load_artifact_model + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--' + constants.EVALUATION_TAR_NAME, type=str, required=True, + help='tarfile of evaluation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Mutect3 artifact model from train_model.py') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard') + parser.add_argument('--' + constants.BATCH_SIZE_NAME, type=int, default=64, required=False, help='batch size') + parser.add_argument('--' + constants.NUM_WORKERS_NAME, type=int, default=0, required=False, + help='number of subprocesses devoted to data loading, which includes reading from memory map, ' + 'collating batches, and transferring to GPU.') + + return parser.parse_args() + + +def main_without_parsing(args): + data_tarfile = getattr(args, constants.EVALUATION_TAR_NAME) + saved_artifact_model = getattr(args, constants.M3_MODEL_NAME) + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + batch_size = getattr(args, constants.BATCH_SIZE_NAME) + num_workers = getattr(args, constants.NUM_WORKERS_NAME) + + dataset = BaseDataset(data_tarfile=data_tarfile, num_folds=10) + device = utils.gpu_if_available() + artifact_model, _, _ = load_artifact_model(saved_artifact_model, device=device) + summary_writer = SummaryWriter(tensorboard_dir) + artifact_model.evaluate_model_after_training(dataset, batch_size, num_workers, summary_writer) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/filter_variants.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/filter_variants.py new file mode 100644 index 00000000000..9a50c2ad1fa --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/filter_variants.py @@ -0,0 +1,399 @@ +import argparse +import math +from collections import defaultdict +from typing import Set + +import cyvcf2 +import psutil +import torch +from intervaltree import IntervalTree +from torch.utils.tensorboard import SummaryWriter +from tqdm.autonotebook import tqdm + +from permutect import constants, utils +from permutect.architecture.artifact_model import ArtifactModel, load_base_model_and_artifact_model +from permutect.architecture.posterior_model import PosteriorModel +from permutect.architecture.base_model import BaseModel +from permutect.data import base_dataset, plain_text_data, base_datum +from permutect.data.base_datum import Variant +from permutect.data.posterior import PosteriorDataset, PosteriorDatum, PosteriorBatch +from permutect.data.artifact_dataset import ArtifactDataset +from permutect.metrics.evaluation_metrics import EvaluationMetrics, PosteriorResult, EmbeddingMetrics, \ + round_up_to_nearest_three, MAX_COUNT +from permutect.utils import Call, find_variant_type, Label, Variation, Epoch, trim_alleles_on_right + +TRUSTED_M2_FILTERS = {'contamination'} + +POST_PROB_INFO_KEY = 'POST' +ARTIFACT_LOD_INFO_KEY = 'ARTLOD' +LOG_PRIOR_INFO_KEY = 'PRIOR' +SPECTRA_LOG_LIKELIHOOD_INFO_KEY = 'SPECLL' +NORMAL_LOG_LIKELIHOOD_INFO_KEY = 'NORMLL' + +FILTER_NAMES = [call_type.name.lower() for call_type in Call] + + +# the inverse of the sigmoid function. Convert a probability to a logit. +def prob_to_logit(prob: float): + clipped_prob = 0.5 + 0.9999999 * (prob - 0.5) + return math.log(clipped_prob / (1 - clipped_prob)) + + +def get_first_numeric_element(variant, key): + tuple_or_scalar = variant.INFO[key] + return tuple_or_scalar[0] if type(tuple_or_scalar) is tuple else tuple_or_scalar + + +# if alt and ref alleles are not in minimal representation ie have redundant matching bases at the end, trim them + + +# TODO: contigs stored as integer index must be converted back to string to compare VCF variants with dataset variants!!! +def encode(contig: str, position: int, ref: str, alt: str): + trimmed_ref, trimmed_alt = trim_alleles_on_right(ref, alt) + return contig + ':' + str(position) + ':' + base_datum.truncate_bases_if_necessary(trimmed_alt) + + +def encode_datum(variant: Variant, contig_index_to_name_map): + contig_name = contig_index_to_name_map[variant.contig] + return encode(contig_name, variant.position, variant.ref, variant.alt) + + +def encode_variant(v: cyvcf2.Variant, zero_based=False): + alt = v.ALT[0] # TODO: we're assuming biallelic + ref = v.REF + start = (v.start + 1) if zero_based else v.start + return encode(v.CHROM, start, ref, alt) + + +def filters_to_keep_from_m2(v: cyvcf2.Variant) -> Set[str]: + return set([]) if v.FILTER is None else set(v.FILTER.split(";")).intersection(TRUSTED_M2_FILTERS) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--' + constants.INPUT_NAME, required=True, help='unfiltered input Mutect2 VCF') + parser.add_argument('--' + constants.TEST_DATASET_NAME, required=True, + help='plain text dataset file corresponding to variants in input VCF') + parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Permutect model from train_model.py') + parser.add_argument('--' + constants.CONTIGS_TABLE_NAME, required=True, help='table of contig names vs integer indices') + parser.add_argument('--' + constants.OUTPUT_NAME, required=True, help='path to output filtered VCF') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard') + parser.add_argument('--' + constants.BATCH_SIZE_NAME, type=int, default=64, required=False, help='batch size') + parser.add_argument('--' + constants.NUM_WORKERS_NAME, type=int, default=0, required=False, + help='number of subprocesses devoted to data loading, which includes reading from memory map, ' + 'collating batches, and transferring to GPU.') + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=100000, required=False, help='size in bytes of intermediate binary datasets') + parser.add_argument('--' + constants.NUM_SPECTRUM_ITERATIONS_NAME, type=int, default=10, required=False, + help='number of epochs for fitting allele fraction spectra') + parser.add_argument('--' + constants.SPECTRUM_LEARNING_RATE_NAME, type=float, default=0.001, required=False, + help='learning rate for fitting allele fraction spectra') + parser.add_argument('--' + constants.INITIAL_LOG_VARIANT_PRIOR_NAME, type=float, default=-10.0, required=False, + help='initial value for natural log prior of somatic variants') + parser.add_argument('--' + constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME, type=float, default=-10.0, required=False, + help='initial value for natural log prior of artifacts') + parser.add_argument('--' + constants.GENOMIC_SPAN_NAME, type=float, required=True, + help='number of sites considered by Mutect2, including those lacking variation or artifacts, hence absent from input dataset. ' + 'Necessary for learning priors since otherwise rates of artifacts and variants would be overinflated.') + parser.add_argument('--' + constants.MAF_SEGMENTS_NAME, required=False, + help='copy-number segmentation file from GATK containing minor allele fractions. ' + 'Useful for modeling germline variation as the minor allele fraction determines the distribution of germline allele counts.') + parser.add_argument('--' + constants.NORMAL_MAF_SEGMENTS_NAME, required=False, + help='copy-number segmentation file from GATK containing minor allele fractions in the normal/control sample') + + parser.add_argument('--' + constants.GERMLINE_MODE_NAME, action='store_true', + help='flag for genotyping both somatic and somatic variants distinctly but considering both ' + 'as non-errors (true positives), which affects the posterior threshold set by optimal F1 score') + parser.add_argument('--' + constants.HET_BETA_NAME, type=float, required=False, + help='beta shape parameter for germline spectrum beta binomial if we want to override binomial') + + parser.add_argument('--' + constants.NO_GERMLINE_MODE_NAME, action='store_true', + help='flag for not genotyping germline events so that the only possibilities considered are ' + 'somatic, artifact, and sequencing error. This is useful for certain validation where ' + 'pseudo-somatic events are created by mixing germline events at varying fractions') + return parser.parse_args() + + +def get_segmentation(segments_file) -> defaultdict: + + result = defaultdict(IntervalTree) + if segments_file is None: + return result + + print("reading segmentation file") + with open(segments_file, 'r') as file: + for line in file: + if line.startswith("#") or (line.startswith("contig") and "minor_allele_fraction" in line): + continue + tokens = line.split() + contig, start, stop, maf = tokens[0], int(tokens[1]), int(tokens[2]), float(tokens[3]) + if stop > start: # IntervalTree throws error if start == stop + result[contig][start:stop] = maf + + return result + + +def main_without_parsing(args): + make_filtered_vcf(saved_artifact_model_path=getattr(args, constants.M3_MODEL_NAME), + initial_log_variant_prior=getattr(args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME), + initial_log_artifact_prior=getattr(args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME), + test_dataset_file=getattr(args, constants.TEST_DATASET_NAME), + contigs_table=getattr(args, constants.CONTIGS_TABLE_NAME), + input_vcf=getattr(args, constants.INPUT_NAME), + output_vcf=getattr(args, constants.OUTPUT_NAME), + batch_size=getattr(args, constants.BATCH_SIZE_NAME), + num_workers=getattr(args, constants.NUM_WORKERS_NAME), + chunk_size=getattr(args, constants.CHUNK_SIZE_NAME), + num_spectrum_iterations=getattr(args, constants.NUM_SPECTRUM_ITERATIONS_NAME), + spectrum_learning_rate=getattr(args, constants.SPECTRUM_LEARNING_RATE_NAME), + tensorboard_dir=getattr(args, constants.TENSORBOARD_DIR_NAME), + genomic_span=getattr(args, constants.GENOMIC_SPAN_NAME), + germline_mode=getattr(args, constants.GERMLINE_MODE_NAME), + no_germline_mode=getattr(args, constants.NO_GERMLINE_MODE_NAME), + het_beta=getattr(args, constants.HET_BETA_NAME), + segmentation=get_segmentation(getattr(args, constants.MAF_SEGMENTS_NAME)), + normal_segmentation=get_segmentation(getattr(args, constants.NORMAL_MAF_SEGMENTS_NAME))) + + +def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float, + test_dataset_file, contigs_table, input_vcf, output_vcf, batch_size: int, num_workers: int, chunk_size: int, num_spectrum_iterations: int, + spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False, het_beta: float = None, + segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)): + print("Loading artifact model and test dataset") + contig_index_to_name_map = {} + with open(contigs_table) as file: + while line := file.readline().strip(): + contig, index = line.split() + contig_index_to_name_map[int(index)] = contig + + device = utils.gpu_if_available() + base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \ + load_base_model_and_artifact_model(saved_artifact_model_path, device=device) + + posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features, het_beta=het_beta) + posterior_data_loader = make_posterior_data_loader(test_dataset_file, input_vcf, contig_index_to_name_map, + base_model, artifact_model, batch_size, num_workers=num_workers, chunk_size=chunk_size, segmentation=segmentation, normal_segmentation=normal_segmentation) + + print("Learning AF spectra") + summary_writer = SummaryWriter(tensorboard_dir) + + num_ignored_sites = genomic_span - len(posterior_data_loader.dataset) + # here is where pretrained artifact priors and spectra are used if given + + posterior_model.learn_priors_and_spectra(posterior_data_loader, num_iterations=num_spectrum_iterations, + summary_writer=summary_writer, ignored_to_non_ignored_ratio=num_ignored_sites/len(posterior_data_loader.dataset), + learning_rate=spectrum_learning_rate) + + print("Calculating optimal logit threshold") + error_probability_thresholds = posterior_model.calculate_probability_thresholds(posterior_data_loader, summary_writer, germline_mode=germline_mode) + print(f"Optimal probability threshold: {error_probability_thresholds}") + apply_filtering_to_vcf(input_vcf, output_vcf, contig_index_to_name_map, error_probability_thresholds, posterior_data_loader, posterior_model, summary_writer=summary_writer, germline_mode=germline_mode) + + +@torch.inference_mode() +def make_posterior_data_loader(dataset_file, input_vcf, contig_index_to_name_map, base_model: BaseModel, artifact_model: ArtifactModel, + batch_size: int, num_workers: int, chunk_size: int, segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)): + print("Reading test dataset") + + m2_filtering_to_keep = set() + allele_frequencies = {} + + print("recording M2 filters and allele frequencies from input VCF") + pbar = tqdm(enumerate(cyvcf2.VCF(input_vcf)), mininterval=60) + for n, v in pbar: + encoding = encode_variant(v, zero_based=True) + if filters_to_keep_from_m2(v): + m2_filtering_to_keep.add(encoding) + allele_frequencies[encoding] = 10 ** (-get_first_numeric_element(v, "POPAF")) + + # pass through the plain text dataset, normalizing and creating ReadSetDatasets as we go, running the artifact model + # to get artifact logits, which we record in a dict keyed by variant strings. These will later be added to PosteriorDatum objects. + print("reading dataset and calculating artifact logits") + print(f"Memory usage percent before loading data: {psutil.virtual_memory().percent:.1f}") + posterior_data = [] + for list_of_base_data in plain_text_data.generate_normalized_data([dataset_file], chunk_size): + print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}") + raw_dataset = base_dataset.BaseDataset(data_in_ram=list_of_base_data) + print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + artifact_dataset = ArtifactDataset(raw_dataset, base_model) + print(f"Memory usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + artifact_loader = artifact_dataset.make_data_loader(artifact_dataset.all_folds(), batch_size, pin_memory=torch.cuda.is_available(), num_workers=num_workers) + + print("creating posterior data for this chunk...") + pbar = tqdm(enumerate(artifact_loader), mininterval=60) + for n, artifact_batch_cpu in pbar: + artifact_batch = artifact_batch_cpu.copy_to(device=artifact_model._device, dtype=artifact_model._dtype, non_blocking=artifact_model._device.type == 'cuda') + artifact_logits, _, _ = artifact_model.forward(batch=artifact_batch) + + labels = [(Label.ARTIFACT if label > 0.5 else Label.VARIANT) if is_labeled > 0.5 else Label.UNLABELED for (label, is_labeled) in zip(artifact_batch.get_training_labels(), artifact_batch.get_is_labeled_mask())] + + for variant,counts_and_seq_lks, logit, label, embedding in zip(artifact_batch_cpu.get_variants(), + artifact_batch_cpu.get_counts_and_seq_lks(), + artifact_logits.detach().tolist(), + labels, + artifact_batch.get_representations_2d().cpu()): + contig_name = contig_index_to_name_map[variant.contig] + encoding = encode(contig_name, variant.position, variant.ref, variant.alt) + if encoding in allele_frequencies and encoding not in m2_filtering_to_keep: + allele_frequency = allele_frequencies[encoding] + + # these are default dicts, so if there's no segmentation for the contig we will get no overlaps but not an error + # For a general IntervalTree there is a list of potentially multiple overlaps but here there is either one or zero + segmentation_overlaps = segmentation[contig_name][variant.position] + normal_segmentation_overlaps = normal_segmentation[contig_name][variant.position] + maf = list(segmentation_overlaps)[0].data if segmentation_overlaps else 0.5 + normal_maf = list(normal_segmentation_overlaps)[0].data if normal_segmentation_overlaps else 0.5 + + posterior_datum = PosteriorDatum(variant, counts_and_seq_lks, allele_frequency, logit, embedding, label, maf, normal_maf) + posterior_data.append(posterior_datum) + + print(f"Size of filtering dataset: {len(posterior_data)}") + posterior_dataset = PosteriorDataset(posterior_data) + print(f"Memory usage percent after creating PosteriorDataset: {psutil.virtual_memory().percent:.1f}") + return posterior_dataset.make_data_loader(batch_size, pin_memory=torch.cuda.is_available(), num_workers=num_workers) + + +# error probability thresholds is a dict from Variant type to error probability threshold (float) +@torch.inference_mode() +def apply_filtering_to_vcf(input_vcf, output_vcf, contig_index_to_name_map, error_probability_thresholds, + posterior_loader, posterior_model, summary_writer: SummaryWriter, germline_mode: bool = False): + print("Computing final error probabilities") + passing_call_type = Call.GERMLINE if germline_mode else Call.SOMATIC + encoding_to_posterior_results = {} + + pbar = tqdm(enumerate(posterior_loader), mininterval=60) + batch_cpu: PosteriorBatch + for n, batch_cpu in pbar: + batch = batch_cpu.copy_to(device=posterior_model._device, dtype=posterior_model._dtype, non_blocking=posterior_model._device.type == 'cuda') + # posterior, along with intermediate tensors for debugging/interpretation + log_priors, spectra_lls, normal_lls, log_posteriors = \ + posterior_model.log_posterior_and_ingredients(batch) + + posterior_probs = torch.nn.functional.softmax(log_posteriors, dim=1) + + encodings = [encode_datum(variant, contig_index_to_name_map) for variant in batch.get_variants()] + artifact_logits = batch.get_artifact_logits().tolist() + var_types = batch.get_variant_types().tolist() + labels = batch.get_labels().tolist() + alt_counts = batch.get_alt_counts().tolist() + depths = batch.get_depths().tolist() + + for encoding, post_probs, logit, log_prior, log_spec, log_normal, label, alt_count, depth, var_type, embedding in zip(encodings, posterior_probs, artifact_logits, log_priors, spectra_lls, normal_lls, labels, alt_counts, depths, var_types, batch.embeddings): + encoding_to_posterior_results[encoding] = PosteriorResult(logit, post_probs.tolist(), log_prior, log_spec, log_normal, label, alt_count, depth, var_type, embedding) + + print("Applying threshold") + unfiltered_vcf = cyvcf2.VCF(input_vcf) + + all_types = [call_type.name for call_type in Call] + unfiltered_vcf.add_format_to_header( {'ID': "DP", 'Description': "depth", 'Type': 'Integer', 'Number': '1'}) + unfiltered_vcf.add_info_to_header({'ID': POST_PROB_INFO_KEY, 'Description': 'Mutect3 posterior probability of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': LOG_PRIOR_INFO_KEY, 'Description': 'Log priors of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': SPECTRA_LOG_LIKELIHOOD_INFO_KEY, 'Description': 'Log spectra likelihoods of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': NORMAL_LOG_LIKELIHOOD_INFO_KEY, 'Description': 'Log normal likelihoods of {' + ', '.join(all_types) + '}', + 'Type': 'Float', 'Number': 'A'}) + unfiltered_vcf.add_info_to_header({'ID': ARTIFACT_LOD_INFO_KEY, 'Description': 'Mutect3 artifact log odds', + 'Type': 'Float', 'Number': 'A'}) + + for n, filter_name in enumerate(FILTER_NAMES): + if n != passing_call_type: + unfiltered_vcf.add_filter_to_header({'ID': filter_name, 'Description': filter_name}) + + writer = cyvcf2.Writer(output_vcf, unfiltered_vcf) # input vcf is a template for the header + evaluation_metrics = EvaluationMetrics() + pbar = tqdm(enumerate(unfiltered_vcf), mininterval=60) + labeled_truth = False + embedding_metrics = EmbeddingMetrics() # only if there is labeled truth for evaluation + + missing_encodings = [] + for n, v in pbar: + filters = filters_to_keep_from_m2(v) + + # TODO: in germline mode, somatic doesn't exist (or is just highly irrelevant) and germline is not an error! + encoding = encode_variant(v, zero_based=True) # cyvcf2 is zero-based + if encoding in encoding_to_posterior_results: + posterior_result = encoding_to_posterior_results[encoding] + post_probs = posterior_result.posterior_probabilities + v.INFO[POST_PROB_INFO_KEY] = ','.join(map(lambda prob: "{:.3f}".format(prob), post_probs)) + v.INFO[LOG_PRIOR_INFO_KEY] = ','.join(map(lambda pri: "{:.3f}".format(pri), posterior_result.log_priors)) + v.INFO[SPECTRA_LOG_LIKELIHOOD_INFO_KEY] = ','.join(map(lambda ll: "{:.3f}".format(ll), posterior_result.spectra_lls)) + v.INFO[ARTIFACT_LOD_INFO_KEY] = "{:.3f}".format(posterior_result.artifact_logit) + v.INFO[NORMAL_LOG_LIKELIHOOD_INFO_KEY] = ','.join(map(lambda ll: "{:.3f}".format(ll), posterior_result.normal_lls)) + + label = Label(posterior_result.label) # this is the Label enum, might be UNLABELED + error_prob = 1 - post_probs[passing_call_type] + variant_type = find_variant_type(v) + called_as_error = error_prob > error_probability_thresholds[variant_type] + + error_call = None + + if called_as_error: + # get the error type with the largest posterior probability + highest_prob_indices = torch.topk(torch.Tensor(post_probs), 2).indices.tolist() + highest_prob_index = highest_prob_indices[1] if highest_prob_indices[0] == passing_call_type else highest_prob_indices[0] + error_call = list(Call)[highest_prob_index] + filters.add(FILTER_NAMES[highest_prob_index]) + + # note that this excludes the correctness part of embedding metrics, which is below + embedding_metrics.label_metadata.append(label.name) + embedding_metrics.type_metadata.append(variant_type.name) + embedding_metrics.truncated_count_metadata.append( + str(round_up_to_nearest_three(min(MAX_COUNT, posterior_result.alt_count)))) + embedding_metrics.representations.append(posterior_result.embedding) + + correctness_label = "unknown" + if label != Label.UNLABELED: + labeled_truth = True + clipped_error_prob = 0.5 + 0.9999999 * (error_prob - 0.5) + error_logit = prob_to_logit(clipped_error_prob) + float_label = 1.0 if label == Label.ARTIFACT else 0.0 + + # TODO: this is sloppy -- it only works because when we label the posterior dataset (if truth is available) + # TODO: we stretch the definitions so that "Label.ARTIFACT" simply means "something we shouldn't call", including + # TODO: artifact or germline (in the somatic calling case), and "Label.VARIANT" means "something we should call" + is_correct = (called_as_error and label == Label.ARTIFACT) or (not called_as_error and label == Label.VARIANT) + evaluation_metrics.record_call(Epoch.TEST, variant_type, error_logit, float_label, is_correct, posterior_result.alt_count) + + # TODO: double-check the logic here + if is_correct: + if label == Label.VARIANT: + correctness_label = EmbeddingMetrics.TRUE_POSITIVE + elif error_call == Call.ARTIFACT or error_call == Call.NORMAL_ARTIFACT: + correctness_label = EmbeddingMetrics.TRUE_NEGATIVE_ARTIFACT + #elif error_call == Call.SEQ_ERROR: + # correctness_label = EmbeddingMetrics.TRUE_NEGATIVE_SEQ_ERROR + # we don't do anything for germline (in somatic mode) or seq error -- + else: + if called_as_error: + if error_call == Call.ARTIFACT or error_call == Call.NORMAL_ARTIFACT: + correctness_label = EmbeddingMetrics.FALSE_NEGATIVE_ARTIFACT + else: + correctness_label = EmbeddingMetrics.FALSE_POSITIVE + # TODO: this is only right for somatic calling + bad_call = error_call if called_as_error else Call.SOMATIC + evaluation_metrics.record_mistake(posterior_result, bad_call) + embedding_metrics.correct_metadata.append(correctness_label) + else: + missing_encodings.append(encoding) + v.FILTER = ';'.join(filters) if filters else 'PASS' + writer.write_record(v) + print("closing resources") + writer.close() + unfiltered_vcf.close() + + embedding_metrics.output_to_summary_writer(summary_writer, is_filter_variants=True) + + if labeled_truth: + given_thresholds = {var_type: prob_to_logit(error_probability_thresholds[var_type]) for var_type in Variation} + evaluation_metrics.make_plots(summary_writer, given_thresholds, sens_prec=True) + evaluation_metrics.make_mistake_histograms(summary_writer) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/preprocess_dataset.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/preprocess_dataset.py new file mode 100644 index 00000000000..1eda5e64392 --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/preprocess_dataset.py @@ -0,0 +1,67 @@ +import argparse +import os +import tarfile +import tempfile +from typing import List + +from permutect import constants +from permutect.data import base_datum +from permutect.data.plain_text_data import generate_normalized_data +from permutect.utils import ConsistentValue + +""" +This tool takes as input a list of text file Mutect3 training datasets, reads them in chunks that fit in memory, +normalizes each chunk, outputs each chunk as a binary PyTorch file, and bundles the output as a tarfile. +""" + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='preprocess plain text training dataset into tarfile of nprmalized binary data') + parser.add_argument('--' + constants.TRAINING_DATASETS_NAME, nargs='+', type=str, required=True, + help='list of plain text data files') + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=int(2e9), required=False, + help='size in bytes of output binary data files') + parser.add_argument('--' + constants.SOURCES_NAME, nargs='+', type=int, required=False, + help='integer sources corresponding to plain text data files for distinguishing different sequencing conditions') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, default=None, required=True, + help='path to output tarfile of training data') + return parser.parse_args() + + +def do_work(training_datasets, training_output_file, chunk_size, sources: List[int]): + data_files = [] + num_read_features, num_info_features, ref_sequence_length = ConsistentValue(), ConsistentValue(), ConsistentValue() + + # save all the lists of read sets to tempfiles. . . + # TODO: left off here. Need to give it sources, which will need to be command line argument + for base_data_list in generate_normalized_data(training_datasets, max_bytes_per_chunk=chunk_size, sources=sources): + num_read_features.check(base_data_list[0].get_reads_2d().shape[1]) + num_info_features.check(base_data_list[0].get_info_tensor_1d().shape[0]) + ref_sequence_length.check(base_data_list[0].get_ref_sequence_1d().shape[0]) + + with tempfile.NamedTemporaryFile(delete=False) as train_data_file: + base_datum.save_list_base_data(base_data_list, train_data_file) + data_files.append(train_data_file.name) + + # . . . and bundle them in a tarfile + with tarfile.open(training_output_file, "w") as train_tar: + for train_file in data_files: + train_tar.add(train_file, arcname=os.path.basename(train_file)) + + +def main_without_parsing(args): + chunk_size = getattr(args, constants.CHUNK_SIZE_NAME) + training_datasets = getattr(args, constants.TRAINING_DATASETS_NAME) + output_file = getattr(args, constants.OUTPUT_NAME) + sources = getattr(args, constants.SOURCES_NAME) + + do_work(training_datasets, output_file, chunk_size, sources) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/prune_dataset.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/prune_dataset.py new file mode 100644 index 00000000000..fd6d1d292f9 --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/prune_dataset.py @@ -0,0 +1,249 @@ +import argparse +import os +import tarfile +import tempfile +from typing import List + +import psutil + +from permutect.architecture.base_model import load_base_model, BaseModel +from permutect.data import base_datum +from tqdm.autonotebook import tqdm + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from permutect import constants, utils +from permutect.architecture.artifact_model import ArtifactModel +from permutect.data.base_datum import ArtifactDatum, ArtifactBatch +from permutect.data.artifact_dataset import ArtifactDataset +from permutect.parameters import ArtifactModelParameters, parse_artifact_model_params, \ + add_artifact_model_params_to_parser, add_training_params_to_parser +from permutect.data.base_dataset import BaseDataset +from permutect.tools.train_model import TrainingParameters, parse_training_params +from permutect.utils import Label + +NUM_FOLDS = 3 + + +# labeled only pruning loader must be constructed with options to emit batches of all-labeled data +def calculate_pruning_thresholds(labeled_only_pruning_loader, artifact_model: ArtifactModel, label_art_frac: float, training_params: TrainingParameters) -> List[int]: + for fold in range(NUM_FOLDS): + average_artifact_confidence, average_nonartifact_confidence = utils.StreamingAverage(), utils.StreamingAverage() + # TODO: eventually this should all be segregated by variant type and maybe also alt count + + # the 0th/1st element is a list of predicted probabilities that data labeled as non-artifact/artifact are actually non-artifact/artifact + probs_of_agreeing_with_label = [[],[]] + print("calculating average confidence and gathering predicted probabilities") + pbar = tqdm(enumerate(labeled_only_pruning_loader), mininterval=60) + for n, batch in pbar: + # TODO: should we use likelihoods as in evaluation or posteriors as in training??? + # TODO: does it even matter?? + art_logits, _, _ = artifact_model.forward(batch) + art_probs = torch.sigmoid(art_logits.detach()) + + labels = batch.get_training_labels() + art_label_mask = (labels > 0.5) + nonart_label_mask = (labels < 0.5) + average_artifact_confidence.record_with_mask(art_probs, art_label_mask) + average_nonartifact_confidence.record_with_mask(1 - art_probs, nonart_label_mask) + + for art_prob, labeled_as_art in zip(art_probs.tolist(), art_label_mask.tolist()): + agreement_prob = art_prob if labeled_as_art else (1 - art_prob) + probs_of_agreeing_with_label[1 if labeled_as_art else 0].append(agreement_prob) + + # TODO: it is wasteful to run forward passes on all the data again when we can just record indices and logits + print("estimating error rates") + # The i,j element is the count of data labeled as i that pass the confidence threshold for j + # here 0 means non-artifact and 1 means artifact + confusion = [[0, 0], [0, 0]] + art_conf_threshold = average_artifact_confidence.get() + nonart_conf_threshold = average_nonartifact_confidence.get() + pbar = tqdm(enumerate(labeled_only_pruning_loader), mininterval=60) + for n, batch in pbar: + predicted_artifact_logits, _, _ = artifact_model.forward(batch) + predicted_artifact_probs = torch.sigmoid(predicted_artifact_logits.detach()) + + conf_art_mask = predicted_artifact_probs >= art_conf_threshold + conf_nonart_mask = (1 - predicted_artifact_probs) >= nonart_conf_threshold + art_label_mask = (batch.get_training_labels() > 0.5) + + for conf_artifact, conf_nonartifact, artifact_label in zip(conf_art_mask.tolist(), conf_nonart_mask.tolist(), art_label_mask.tolist()): + row = 1 if artifact_label else 0 + if conf_artifact: + confusion[row][1] += 1 + if conf_nonartifact: + confusion[row][0] += 1 + + # these are the probabilities of a true (hidden label) artifact/non-artifact being mislabeled as non-artifact/artifact + art_error_rate = confusion[0][1] / (confusion[0][1] + confusion[1][1]) + nonart_error_rate = confusion[1][0] / (confusion[0][0] + confusion[1][0]) + + # fraction of labeled data that are labeled as artifact + label_nonart_frac = 1 - label_art_frac + + # these are the inverse probabilities that something labeled as artifact/non-artifact was actually a mislabeled nonartifact/artifact + inv_art_error_rate = (nonart_error_rate / label_art_frac) * (label_nonart_frac - art_error_rate) / (1 - art_error_rate - nonart_error_rate) + inv_nonart_error_rate = (art_error_rate / label_nonart_frac) * (label_art_frac - nonart_error_rate) / (1 - art_error_rate - nonart_error_rate) + + print("Estimated error rates: ") + print(f"artifact mislabeled as non-artifact: {art_error_rate:.3f}") + print(f"non-artifact mislabeled as artifact: {nonart_error_rate:.3f}") + + print("Estimated inverse error rates: ") + print(f"Labeled artifact was actually non-artifact: {inv_art_error_rate:.3f}") + print(f"Labeled non-artifact was actually artifact: {inv_nonart_error_rate:.3f}") + + print("calculating rank pruning thresholds") + nonart_threshold = torch.quantile(torch.Tensor(probs_of_agreeing_with_label[0]), inv_nonart_error_rate).item() + art_threshold = torch.quantile(torch.Tensor(probs_of_agreeing_with_label[1]), inv_art_error_rate).item() + + print("Rank pruning thresholds: ") + print(f"Labeled artifacts are pruned if predicted artifact probability is less than {art_threshold:.3f}") + print(f"Labeled non-artifacts are pruned if predicted non-artifact probability is less than {nonart_threshold:.3f}") + + return art_threshold, nonart_threshold + + +# generates BaseDatum(s) from the original dataset that *pass* the pruning thresholds +def generated_pruned_data_for_fold(art_threshold: float, nonart_threshold: float, pruning_base_data_loader, + base_model: BaseModel, artifact_model: ArtifactModel) -> List[int]: + print("pruning the dataset") + pbar = tqdm(enumerate(pruning_base_data_loader), mininterval=60) + for n, base_batch in pbar: + # apply the representation model AND the artifact model to go from the original read set to artifact logits + representation, _ = base_model.calculate_representations(base_batch) + + artifact_batch = ArtifactBatch([ArtifactDatum(rs, rep) for rs, rep in zip(base_batch.original_list(), representation.detach())]) + + art_logits, _, _ = artifact_model.forward(artifact_batch) + art_probs = torch.sigmoid(art_logits.detach()) + art_label_mask = (base_batch.get_training_labels() > 0.5) + is_labeled_mask = (base_batch.get_is_labeled_mask() > 0.5) + + for art_prob, labeled_as_art, datum, is_labeled in zip(art_probs.tolist(), art_label_mask.tolist(), base_batch.original_list(), is_labeled_mask.tolist()): + if not is_labeled: + yield datum + elif (labeled_as_art and art_prob < art_threshold) or ((not labeled_as_art) and (1-art_prob) < nonart_threshold): + # TODO: process failing data, perhaps add option to output a pruned dataset? or flip labels? + pass + else: + yield datum # this is a ReadSet + + +def generate_pruned_data_for_all_folds(base_dataset: BaseDataset, base_model: BaseModel, + training_params: TrainingParameters, params: ArtifactModelParameters, tensorboard_dir): + # for each fold in turn, train an artifact model on all other folds and prune the chosen fold + use_gpu = torch.cuda.is_available() + device = torch.device('cuda' if use_gpu else 'cpu') + + for pruning_fold in range(NUM_FOLDS): + summary_writer = SummaryWriter(tensorboard_dir + "/fold_" + str(pruning_fold)) + print(f"Pruning data from fold {pruning_fold} of {NUM_FOLDS}") + print(f"Memory usage percent: {psutil.virtual_memory().percent:.3f}") + + # learn an artifact model with the pruning data held out + artifact_dataset = ArtifactDataset(base_dataset, base_model, base_dataset.all_but_one_fold(pruning_fold)) + + # sum is over variant types + label_art_frac = np.sum(artifact_dataset.totals[-1][Label.ARTIFACT]) / np.sum(artifact_dataset.totals[-1][Label.ARTIFACT] + + artifact_dataset.totals[-1][Label.VARIANT]) + + # learn pruning thresholds on the held-out data + pruning_artifact_dataset = ArtifactDataset(base_dataset, base_model, [pruning_fold]) + labeled_only_pruning_loader = pruning_artifact_dataset.make_data_loader(pruning_artifact_dataset.all_folds(), + training_params.batch_size, use_gpu, training_params.num_workers, labeled_only=True) + model = ArtifactModel(params=params, num_base_features=artifact_dataset.num_base_features, num_ref_alt_features=base_model.ref_alt_seq_embedding_dimension(), device=device).float() + model.learn(artifact_dataset, training_params, summary_writer=summary_writer) + + # TODO: maybe this should be done by variant type and/or count + art_threshold, nonart_threshold = calculate_pruning_thresholds(labeled_only_pruning_loader, model, label_art_frac, training_params) + + # unlike when learning thresholds, we load labeled and unlabeled data here + pruning_base_data_loader = base_dataset.make_data_loader([pruning_fold], training_params.batch_size, use_gpu, training_params.num_epochs) + for passing_base_datum in generated_pruned_data_for_fold(art_threshold, nonart_threshold, pruning_base_data_loader, base_model, model): + yield passing_base_datum + + +# takes a ReadSet generator and organies into buffers. +# TODO: probably code duplication since the generator is already pruned +def generate_pruned_data_buffers(pruned_data_generator, max_bytes_per_chunk: int): + buffer, bytes_in_buffer = [], 0 + for datum in pruned_data_generator: + + buffer.append(datum) + bytes_in_buffer += datum.size_in_bytes() + if bytes_in_buffer > max_bytes_per_chunk: + print(f"Memory usage percent: {psutil.virtual_memory().percent:.1f}") + print(f"{bytes_in_buffer} bytes in chunk") + yield buffer + buffer, bytes_in_buffer = [], 0 + + # There will be some data left over, in general. + if buffer: + yield buffer + + +def make_pruned_training_dataset(pruned_data_buffer_generator, pruned_tarfile): + pruned_data_files = [] + for base_data_list in pruned_data_buffer_generator: + with tempfile.NamedTemporaryFile(delete=False) as train_data_file: + base_datum.save_list_base_data(base_data_list, train_data_file) + pruned_data_files.append(train_data_file.name) + + # bundle them in a tarfile + with tarfile.open(pruned_tarfile, "w") as train_tar: + for train_file in pruned_data_files: + train_tar.add(train_file, arcname=os.path.basename(train_file)) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Mutect3 artifact model') + + add_artifact_model_params_to_parser(parser) + add_training_params_to_parser(parser) + + parser.add_argument('--' + constants.CHUNK_SIZE_NAME, type=int, default=int(2e9), required=False, + help='size in bytes of output binary data files') + + # input / output + parser.add_argument('--' + constants.TRAIN_TAR_NAME, type=str, required=True, + help='tarfile of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='path to pruned dataset file') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, + help='path to output tensorboard directory') + + return parser.parse_args() + + +def main_without_parsing(args): + params = parse_artifact_model_params(args) + training_params = parse_training_params(args) + + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + pruned_tarfile = getattr(args, constants.OUTPUT_NAME) + chunk_size = getattr(args, constants.CHUNK_SIZE_NAME) + original_tarfile = getattr(args, constants.TRAIN_TAR_NAME) + + base_model = load_base_model(getattr(args, constants.BASE_MODEL_NAME)) + base_dataset = BaseDataset(data_tarfile=original_tarfile, num_folds=NUM_FOLDS) + + # generate ReadSets passing pruning + pruned_data_generator = generate_pruned_data_for_all_folds(base_dataset, base_model, training_params, params, tensorboard_dir) + + # generate List[ReadSet]s passing pruning + pruned_data_buffer_generator = generate_pruned_data_buffers(pruned_data_generator, chunk_size) + + # save as a tarfile dataset + make_pruned_training_dataset(pruned_data_buffer_generator, pruned_tarfile=pruned_tarfile) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/train_base_model.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/train_base_model.py new file mode 100644 index 00000000000..4ba3e2bcb40 --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/train_base_model.py @@ -0,0 +1,60 @@ +import argparse + +from torch.utils.tensorboard import SummaryWriter + +from permutect import constants, utils +from permutect.architecture.base_model import BaseModel, LearningMethod, load_base_model, learn_base_model +from permutect.parameters import BaseModelParameters, TrainingParameters, parse_training_params, \ + parse_base_model_params, add_base_model_params_to_parser, add_training_params_to_parser +from permutect.data.base_dataset import BaseDataset + + +def train_base_model(params: BaseModelParameters, training_params: TrainingParameters, learning_method: LearningMethod, + summary_writer: SummaryWriter, dataset: BaseDataset, pretrained_model: BaseModel = None) -> BaseModel: + base_model = pretrained_model if (pretrained_model is not None) else \ + BaseModel(params=params, num_read_features=dataset.num_read_features, num_info_features=dataset.num_info_features, + ref_sequence_length=dataset.ref_sequence_length, device=utils.gpu_if_available()) + learn_base_model(base_model, dataset, learning_method, training_params, summary_writer=summary_writer) + return base_model + + +def main_without_parsing(args): + params = parse_base_model_params(args) + training_params = parse_training_params(args) + + learning_method = LearningMethod[getattr(args, constants.LEARNING_METHOD_NAME)] + + tarfile_data = getattr(args, constants.TRAIN_TAR_NAME) + pretrained_model_path = getattr(args, constants.PRETRAINED_MODEL_NAME) + pretrained_model = None if pretrained_model_path is None else load_base_model(pretrained_model_path) + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + summary_writer = SummaryWriter(tensorboard_dir) + dataset = BaseDataset(data_tarfile=tarfile_data, num_folds=10) + + model = train_base_model(params=params, dataset=dataset, training_params=training_params, learning_method=learning_method, + summary_writer=summary_writer, pretrained_model=pretrained_model) + + summary_writer.close() + model.save(getattr(args, constants.OUTPUT_NAME)) + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Permutect read set representation model') + add_base_model_params_to_parser(parser) + add_training_params_to_parser(parser) + + parser.add_argument('--' + constants.LEARNING_METHOD_NAME, type=str, required=False, default='SEMISUPERVISED') + parser.add_argument('--' + constants.TRAIN_TAR_NAME, type=str, required=True, + help='tarfile of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='output saved model file') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, + help='output tensorboard directory') + + return parser.parse_args() + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/main/resources/org/broadinstitute/hellbender/tools/permutect/train_model.py b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/train_model.py new file mode 100644 index 00000000000..2eaa1ebc0bc --- /dev/null +++ b/src/main/resources/org/broadinstitute/hellbender/tools/permutect/train_model.py @@ -0,0 +1,125 @@ +import argparse + +import psutil +import torch +from torch.utils.tensorboard import SummaryWriter + +from permutect import constants, utils +from permutect.architecture.artifact_model import ArtifactModel +from permutect.architecture.artifact_spectra import ArtifactSpectra +from permutect.architecture.posterior_model import plot_artifact_spectra +from permutect.architecture.base_model import load_base_model +from permutect.data.base_dataset import BaseDataset +from permutect.data.artifact_dataset import ArtifactDataset +from permutect.data.base_datum import ArtifactDatum +from permutect.parameters import TrainingParameters, add_training_params_to_parser, parse_training_params, \ + ArtifactModelParameters, parse_artifact_model_params, add_artifact_model_params_to_parser +from permutect.utils import Variation, Label + + +def train_artifact_model(hyperparams: ArtifactModelParameters, training_params: TrainingParameters, summary_writer: SummaryWriter, dataset: ArtifactDataset): + model = ArtifactModel(params=hyperparams, num_base_features=dataset.num_base_features, num_ref_alt_features=dataset.num_ref_alt_features, device=utils.gpu_if_available()) + # TODO: magic constant + model.learn(dataset, training_params, summary_writer=summary_writer, epochs_per_evaluation=10) + + for n, var_type in enumerate(Variation): + cal_fig, cal_axes = model.calibration[n].plot_calibration() + summary_writer.add_figure("calibration by count for " + var_type.name, cal_fig) + + return model + + +def learn_artifact_priors_and_spectra(artifact_dataset: ArtifactDataset, genomic_span_of_data: int): + artifact_counts = torch.zeros(len(utils.Variation)) + types_list, depths_list, alt_counts_list = [], [], [] + + artifact_datum: ArtifactDatum + for artifact_datum in artifact_dataset: + if artifact_datum.get_label() != Label.ARTIFACT: + continue + variant_type = artifact_datum.get_variant_type() + artifact_counts[variant_type] += 1 + types_list.append(variant_type) + counts_and_seq_lks = artifact_datum.one_dimensional_data.get_counts_and_seq_lks() + depths_list.append(counts_and_seq_lks.depth) + alt_counts_list.append(counts_and_seq_lks.alt_count) + + # turn the lists into tensors + types_tensor = torch.LongTensor(types_list) + depths_tensor = torch.Tensor(depths_list).float() + alt_counts_tensor = torch.Tensor(alt_counts_list).float() + + log_artifact_priors = torch.log(artifact_counts / genomic_span_of_data) + artifact_spectra = ArtifactSpectra(num_components=2) + + # TODO: hard-coded num epochs!!! + artifact_spectra.fit(num_epochs=10, types_b=types_tensor, depths_1d_tensor=depths_tensor, + alt_counts_1d_tensor=alt_counts_tensor, batch_size=64) + + return log_artifact_priors, artifact_spectra + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='train the Permutect artifact model') + + add_artifact_model_params_to_parser(parser) + add_training_params_to_parser(parser) + + parser.add_argument('--' + constants.LEARN_ARTIFACT_SPECTRA_NAME, action='store_true', + help='flag to include artifact priors and allele fraction spectra in saved output. ' + 'This is worth doing if labeled training data is available but might work poorly ' + 'when Mutect3 generates weak labels based on allele fractions.') + parser.add_argument('--' + constants.GENOMIC_SPAN_NAME, type=float, required=False, + help='Total number of sites considered by Mutect2 in all training data, including those lacking variation or artifacts, hence absent from input datasets. ' + 'Necessary for learning priors since otherwise rates of artifacts and variants would be overinflated. ' + 'Only required if learning artifact log priors') + + # inputs and outputs + parser.add_argument('--' + constants.TRAIN_TAR_NAME, type=str, required=True, + help='tarfile of training/validation datasets produced by preprocess_dataset.py') + parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py') + parser.add_argument('--' + constants.OUTPUT_NAME, type=str, required=True, help='path to output saved model file') + parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, + help='path to output tensorboard directory') + + return parser.parse_args() + + +def main_without_parsing(args): + params = parse_artifact_model_params(args) + training_params = parse_training_params(args) + learn_artifact_spectra = getattr(args, constants.LEARN_ARTIFACT_SPECTRA_NAME) + genomic_span = getattr(args, constants.GENOMIC_SPAN_NAME) + + tensorboard_dir = getattr(args, constants.TENSORBOARD_DIR_NAME) + summary_writer = SummaryWriter(tensorboard_dir) + + base_model = load_base_model(getattr(args, constants.BASE_MODEL_NAME)) + print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}") + base_dataset = BaseDataset(data_tarfile=getattr(args, constants.TRAIN_TAR_NAME), num_folds=10) + print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + artifact_dataset = ArtifactDataset(base_dataset, + base_model, + base_loader_num_workers=training_params.num_workers, + base_loader_batch_size=training_params.inference_batch_size) + print(f"Memory usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") + + model = train_artifact_model(hyperparams=params, training_params=training_params, summary_writer=summary_writer, dataset=artifact_dataset) + print(f"Memory usage percent after training artifact model: {psutil.virtual_memory().percent:.1f}") + + artifact_log_priors, artifact_spectra = learn_artifact_priors_and_spectra(artifact_dataset, genomic_span) if learn_artifact_spectra else (None, None) + if artifact_spectra is not None: + art_spectra_fig, art_spectra_axs = plot_artifact_spectra(artifact_spectra, depth=50) + summary_writer.add_figure("Artifact AF Spectra", art_spectra_fig) + + summary_writer.close() + model.save_with_base_model(base_model, getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra) + + +def main(): + args = parse_arguments() + main_without_parsing(args) + + +if __name__ == '__main__': + main() diff --git a/src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java new file mode 100644 index 00000000000..f8a5e7357c5 --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/tools/permutect/PermutectArgumentConstantsUnitTest.java @@ -0,0 +1,139 @@ +package org.broadinstitute.hellbender.tools.permutect; + +import org.broadinstitute.barclay.argparser.Argument; +import org.broadinstitute.barclay.argparser.ArgumentCollection; +import org.broadinstitute.barclay.argparser.CommandLineParser; +import org.broadinstitute.hellbender.GATKBaseTest; +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; +import org.broadinstitute.hellbender.testutils.ArgumentsBuilder; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class PermutectArgumentConstantsUnitTest extends GATKBaseTest { + + private class DummyPermutectArgCollection { + @Argument(fullName = PermutectArgumentConstants.NUM_EPOCHS_NAME, doc = "argument from an argument collection", optional = true) + private String Arg3 = null; + } + + private class dummyPermutectWrapper extends CommandLineProgram { + + @Argument(fullName = "dummy-argument",doc = "not in python argument list", optional = true) + private String Arg1 = null; + + // TMP_DIR_NAME = "tmp_dir" // this is a representative inhereited argument that is present in the python argument list + + @Argument(fullName = PermutectArgumentConstants.OUTPUT_NAME, doc = "a standard permutect argument", optional = false) + private String Arg2 = null; + + @Argument(fullName = PermutectArgumentConstants.INFO_LAYERS_NAME, doc = "in python argument list", optional = true) + private String Arg3 = null; + + @Argument(fullName = PermutectArgumentConstants.BASE_MODEL_NAME, doc = "in python argument list, has GATK defined default value is overwritten", optional = true) + private String Arg4 = "THIS_SHOULD_NOT_BE_HERE"; + + @Argument(fullName = PermutectArgumentConstants.BATCH_SIZE_NAME, doc = "in python argument list, has GATK defined default value, but is not specified on the cli", optional = true) + private String Arg4b = "THIS_SHOULD_NOT_BE_HERE"; + + @Argument(fullName = PermutectArgumentConstants.DROPOUT_P_NAME, doc = "flag argument", optional = true) + private boolean Arg5 = false; + + @Argument(fullName = PermutectArgumentConstants.NUM_READ_FEATURES_NAME, doc = "integer arguments", optional = true) + private int Arg6 = 3; + + @Argument(fullName = PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, doc = "list argument, optional", optional = true) + private List Arg7 = new ArrayList<>(); + + @ArgumentCollection + DummyPermutectArgCollection args = new DummyPermutectArgCollection(); + + @Override + protected Object doWork() { return null; } + } + + @Test + public void testGetPtyhonClassArgumentsFromToolParser() { + ArgumentsBuilder builder = new ArgumentsBuilder(); + builder.add(PermutectArgumentConstants.OUTPUT_NAME, "output"); + builder.add("dummy-argument", "THIS_SHOULD_NOT_BE_HERE"); + builder.add(PermutectArgumentConstants.INFO_LAYERS_NAME, "info_layers"); + builder.add(PermutectArgumentConstants.BASE_MODEL_NAME, "base_model"); + builder.addFlag(PermutectArgumentConstants.DROPOUT_P_NAME); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg1"); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg2"); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg3"); + builder.add(PermutectArgumentConstants.AGGREGATION_LAYERS_NAME, "agg4"); + builder.add(PermutectArgumentConstants.NUM_READ_FEATURES_NAME, "2"); + builder.add(PermutectArgumentConstants.NUM_EPOCHS_NAME, "num_epochs"); + CommandLineParser parser = new dummyPermutectWrapper().getCommandLineParser(); + final boolean conversionMap = parser.parseArguments(new PrintStream(System.err), builder.getArgsArray()); + + List pyArgs = PermutectArgumentConstants.getPtyhonClassArgumentsFromToolParser(parser); + Assert.assertTrue(pyArgs.contains("--output")); + Assert.assertTrue(pyArgs.contains("output")); + Assert.assertEquals(pyArgs.indexOf("output") - 1, pyArgs.indexOf("--output")); + + Assert.assertTrue(pyArgs.contains("--info_layers")); + Assert.assertTrue(pyArgs.contains("info_layers")); + Assert.assertEquals(pyArgs.indexOf("info_layers") - 1, pyArgs.indexOf("--info_layers")); + + Assert.assertTrue(pyArgs.contains("--dropout_p")); + + Assert.assertTrue(pyArgs.contains("--aggregation_layers")); + Assert.assertTrue(pyArgs.contains("agg1")); + Assert.assertTrue(pyArgs.contains("agg2")); + Assert.assertTrue(pyArgs.contains("agg3")); + Assert.assertTrue(pyArgs.contains("agg4")); + Assert.assertEquals(pyArgs.indexOf("agg1") - 1, pyArgs.indexOf("--aggregation_layers")); + Assert.assertEquals(pyArgs.indexOf("agg2") - 1, pyArgs.indexOf("agg1")); + Assert.assertEquals(pyArgs.indexOf("agg3") - 1, pyArgs.indexOf("agg2")); + Assert.assertEquals(pyArgs.indexOf("agg4") - 1, pyArgs.indexOf("agg3")); + + Assert.assertTrue(pyArgs.contains("--num_read_features")); + Assert.assertTrue(pyArgs.contains("2")); + + Assert.assertTrue(pyArgs.contains("--num_epochs")); + Assert.assertTrue(pyArgs.contains("num_epochs")); + Assert.assertEquals(pyArgs.indexOf("num_epochs") - 1, pyArgs.indexOf("--num_epochs")); + + Assert.assertFalse(pyArgs.contains("--dummy-argument")); + Assert.assertFalse(pyArgs.contains("THIS_SHOULD_NOT_BE_HERE")); + + Assert.assertTrue(pyArgs.contains("--base_model")); + Assert.assertTrue(pyArgs.contains("base_model")); + Assert.assertEquals(pyArgs.indexOf("base_model") - 1, pyArgs.indexOf("--base_model")); + + Assert.assertFalse(pyArgs.contains("--tmp_dir")); + Assert.assertFalse(pyArgs.contains("tmp_dir")); + + Assert.assertFalse(pyArgs.contains("--batch_size")); + } + + @Test + public void testGenerateArgumentMap() { + final Map conversionMap = PermutectArgumentConstants.PERMUTECT_PYTHON_ARGUMENT_MAP; + + Assert.assertNotNull(conversionMap); + Assert.assertTrue(conversionMap.entrySet().size() > 30); // assert that the map is not empty and that it reflectively picked up a lot of arguments, the exact number will be subject to change + + for (Map.Entry entry : conversionMap.entrySet()) { + Assert.assertNotNull(entry.getKey()); + Assert.assertFalse(entry.getKey().contains("_")); + + // ptyhon arguments should not contain hyphens + Assert.assertNotNull(entry.getValue()); + Assert.assertFalse(entry.getValue().contains("-")); + } + + // various illegal fields that could have snuck into the reflection by acciedent that we want to make sure didn't + Assert.assertFalse(conversionMap.containsKey("PERMUTECT_PYTHON_ARGUMENT_MAP")); + Assert.assertFalse(conversionMap.containsKey("dragen-mode")); + Assert.assertFalse(conversionMap.containsKey("getPythonClassArgumentsFromToolParser")); + Assert.assertFalse(conversionMap.containsKey("serialVersionUID")); + } +} \ No newline at end of file