diff --git a/genes/management/commands/import_gene_annotation2.py b/genes/management/commands/import_gene_annotation2.py
index 2e9597b7e..49192ffcc 100644
--- a/genes/management/commands/import_gene_annotation2.py
+++ b/genes/management/commands/import_gene_annotation2.py
@@ -7,6 +7,7 @@
 from genes.models import GeneSymbol, GeneAnnotationImport, Gene, GeneVersion, TranscriptVersion, Transcript
 from genes.models_enums import AnnotationConsortium
 from library.file_utils import open_handle_gzip
+from library.utils import invert_dict
 from snpdb.models.models_genome import GenomeBuild
 
 
@@ -28,6 +29,12 @@ def add_arguments(self, parser):
         group.add_argument('--merged-json', help='Merged JSON (from multiple PyReference files)')
 
     def handle(self, *args, **options):
+        build_name = options["genome_build"]
+        annotation_consortium_name = options["annotation_consortium"]
+        genome_build = GenomeBuild.get_name_or_alias(build_name)
+        ac_dict = invert_dict(dict(AnnotationConsortium.choices))
+        annotation_consortium = ac_dict[annotation_consortium_name]
+
         if pyreference_json := options["pyreference_json"]:
             pyreference_data = []
             for prj_filename in pyreference_json:
@@ -47,7 +54,7 @@ def handle(self, *args, **options):
         else:
             raise ValueError("You need to specify at least one of '--pyreference-json' or '--merged-json'")
 
-        self._import_merged_data(merged_data)
+        self._import_merged_data(genome_build, annotation_consortium, merged_data)
 
     @staticmethod
     def _get_most_recent_transcripts(pyreference_data) -> List[Set]:
@@ -96,7 +103,7 @@ def _convert_to_merged_data(self, pyreference_data: List[Dict]) -> List[Dict]:
             data = {
                 "gene_annotation_import": prd["reference_gtf"],
                 "gene_version": gene_version,
-                "transcript_versions": transcript_versions,
+                "transcript_version": transcript_versions,
             }
             merged_data.append(data)
 
@@ -108,20 +115,22 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
         known_gene_symbols = set(GeneSymbol.objects.all().values_list("pk", flat=True))
 
         genes_qs = Gene.objects.filter(annotation_consortium=annotation_consortium)
-        known_genes_ids = {genes_qs.values_list("identifier", flat=True)}
+        known_genes_ids = set(genes_qs.values_list("identifier", flat=True))
         transcripts_qs = Transcript.objects.filter(annotation_consortium=annotation_consortium)
-        known_transcript_ids = {transcripts_qs.values_list("identifier", flat=True)}
+        known_transcript_ids = set(transcripts_qs.values_list("identifier", flat=True))
         gene_version_qs = GeneVersion.objects.filter(genome_build=genome_build,
-                                                     annotation_consortium=annotation_consortium)
+                                                     gene__annotation_consortium=annotation_consortium)
         known_gene_version_ids_by_accession = {f"{gene_id}.{version}": pk for (pk, gene_id, version) in gene_version_qs.values_list("pk", "gene_id", "version")}
 
         transcript_version_qs = TranscriptVersion.objects.filter(genome_build=genome_build,
-                                                                 annotation_consortium=annotation_consortium)
-        known_transcript_version_ids_by_accession = {f"{transcript_id}.{version}": pk for (pk, transcript_id, version) in transcript_version_qs.values_list("pk", "transcript_id", "version")}
+                                                                 transcript__annotation_consortium=annotation_consortium)
+        tv_values = transcript_version_qs.values_list("pk", "transcript_id", "version")
+        known_transcript_version_ids_by_accession = {f"{transcript_id}.{version}": pk
+                                                     for (pk, transcript_id, version) in tv_values}
 
         for data in merged_data:
             import_data = data["gene_annotation_import"]
-            logging.info("%s has %d transcripts", import_data, len(data["transcript_versions"]))
+            logging.info("%s has %d transcripts", import_data, len(data["transcript_version"]))
             import_source = GeneAnnotationImport.objects.create(annotation_consortium=annotation_consortium,
                                                                 genome_build=genome_build,
                                                                 filename=import_data["path"],
@@ -162,17 +171,19 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
             if new_gene_symbols:
                 logging.info("Creating %d new gene symbols", len(new_gene_symbols))
                 GeneSymbol.objects.bulk_create(new_gene_symbols, batch_size=self.BATCH_SIZE)
+                known_gene_symbols.update({gene_symbol.symbol for gene_symbol in new_gene_symbols})
 
             if new_genes:
                 logging.info("Creating %d new genes", len(new_genes))
                 Gene.objects.bulk_create(new_genes, batch_size=self.BATCH_SIZE)
+                # Update with newly inserted records - so that we have a PK to use below
+                known_genes_ids.update({gene.identifier for gene in new_genes})
 
             if new_gene_versions:
                 logging.info("Creating %d new gene versions", len(new_gene_versions))
                 GeneVersion.objects.bulk_create(new_gene_versions, batch_size=self.BATCH_SIZE)
-
-                # Update with newly inserted gene versions - so that we have a PK to use below
-                known_gene_version_ids_by_accession.update({f"{gv.gene_id}.{gv.version}" for gv in new_gene_versions})
+                # Update with newly inserted records - so that we have a PK to use below
+                known_gene_version_ids_by_accession.update({f"{gv.gene_id}.{gv.version}" for gv in new_gene_versions})
 
             # Could potentially be duplicate gene versions (diff transcript versions from diff GFFs w/same GeneVersion)
             if modified_gene_versions:
@@ -181,15 +192,14 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
                                                 ["gene_symbol_id", "hgnc_id", "description", "biotype", "import_source"],
                                                 batch_size=self.BATCH_SIZE)
 
-            new_transcripts = []
+            new_transcript_ids = set()
             new_transcript_versions = []
             modified_transcript_versions = []
 
             for transcript_accession, tv_data in data["transcript_version"].items():
                 transcript_id, version = TranscriptVersion.get_transcript_id_and_version(transcript_accession)
                 if transcript_id not in known_transcript_ids:
-                    new_transcripts.append(Transcript(identifier=transcript_id,
-                                                      annotation_consortium=annotation_consortium))
+                    new_transcript_ids.add(transcript_id)
 
                 gene_version_id = known_gene_version_ids_by_accession[tv_data["gene_version"]]
                 transcript_version = TranscriptVersion(transcript_id=transcript_id,
@@ -205,10 +215,15 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
                 else:
                     new_transcript_versions.append(transcript_version)
 
-            if new_transcripts:
-                logging.info("Creating %d new transcripts", len(new_transcripts))
+            if new_transcript_ids:
+                logging.info("Creating %d new transcripts", len(new_transcript_ids))
+                new_transcripts = [Transcript(identifier=transcript_id, annotation_consortium=annotation_consortium)
+                                   for transcript_id in new_transcript_ids]
+                Transcript.objects.bulk_create(new_transcripts, batch_size=self.BATCH_SIZE)
+                known_transcript_ids.update(new_transcript_ids)
 
+            # No need to update known after insert as there won't be duplicate transcript versions in the merged data
             if new_transcript_versions:
                 logging.info("Creating %d new transcript versions", len(new_transcript_versions))
                 TranscriptVersion.objects.bulk_create(new_transcript_versions, batch_size=self.BATCH_SIZE)
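Note on the new option handling in handle(): the --annotation-consortium name is mapped back to the value stored on the models by inverting the Django choices pairs with invert_dict. A minimal sketch of that lookup follows, assuming a stand-in invert_dict and hypothetical choice values; the real helper lives in library.utils and the real choices on genes.models_enums.AnnotationConsortium.

# Sketch only - invert_dict and the choice values below are stand-ins, not the
# repo's actual library.utils.invert_dict or AnnotationConsortium definitions.
def invert_dict(d: dict) -> dict:
    """Swap keys and values, e.g. {"E": "Ensembl"} -> {"Ensembl": "E"}."""
    return {value: key for key, value in d.items()}

# Django-style choices: (stored value, human-readable label) pairs
ANNOTATION_CONSORTIUM_CHOICES = [("E", "Ensembl"), ("R", "RefSeq")]

ac_dict = invert_dict(dict(ANNOTATION_CONSORTIUM_CHOICES))
annotation_consortium = ac_dict["RefSeq"]  # -> "R", the value filtered on above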