#494 - Insert data fixes
davmlaw committed Sep 30, 2021
1 parent c66fd3e commit 7a42077
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions genes/management/commands/import_gene_annotation2.py
@@ -7,6 +7,7 @@
from genes.models import GeneSymbol, GeneAnnotationImport, Gene, GeneVersion, TranscriptVersion, Transcript
from genes.models_enums import AnnotationConsortium
from library.file_utils import open_handle_gzip
+from library.utils import invert_dict
from snpdb.models.models_genome import GenomeBuild


@@ -28,6 +29,12 @@ def add_arguments(self, parser):
        group.add_argument('--merged-json', help='Merged JSON (from multiple PyReference files)')

    def handle(self, *args, **options):
+        build_name = options["genome_build"]
+        annotation_consortium_name = options["annotation_consortium"]
+        genome_build = GenomeBuild.get_name_or_alias(build_name)
+        ac_dict = invert_dict(dict(AnnotationConsortium.choices))
+        annotation_consortium = ac_dict[annotation_consortium_name]
+
        if pyreference_json := options["pyreference_json"]:
            pyreference_data = []
            for prj_filename in pyreference_json:
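
The new preamble in handle() turns the command-line names into model values: GenomeBuild.get_name_or_alias resolves the build name, and the consortium label is mapped back to its stored code by inverting the Django choices. A minimal sketch of that inversion, assuming invert_dict from library.utils simply swaps keys and values (the choice values below are illustrative stand-ins, not the project's real codes):

    # Hypothetical stand-in for AnnotationConsortium.choices: (db value, label) pairs
    CHOICES = [("R", "RefSeq"), ("E", "Ensembl")]

    def invert_dict(d):
        # assumed behaviour of library.utils.invert_dict: swap keys and values
        return {value: key for key, value in d.items()}

    ac_dict = invert_dict(dict(CHOICES))       # {"RefSeq": "R", "Ensembl": "E"}
    annotation_consortium = ac_dict["RefSeq"]  # CLI label -> stored value "R"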
@@ -47,7 +54,7 @@ def handle(self, *args, **options):
        else:
            raise ValueError("You need to specify at least one of '--pyreference-json' or '--merged-json'")

-        self._import_merged_data(merged_data)
+        self._import_merged_data(genome_build, annotation_consortium, merged_data)

    @staticmethod
    def _get_most_recent_transcripts(pyreference_data) -> List[Set]:
@@ -96,7 +103,7 @@ def _convert_to_merged_data(self, pyreference_data: List[Dict]) -> List[Dict]:
            data = {
                "gene_annotation_import": prd["reference_gtf"],
                "gene_version": gene_version,
-                "transcript_versions": transcript_versions,
+                "transcript_version": transcript_versions,
            }
            merged_data.append(data)
@@ -108,20 +115,22 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,

        known_gene_symbols = set(GeneSymbol.objects.all().values_list("pk", flat=True))
        genes_qs = Gene.objects.filter(annotation_consortium=annotation_consortium)
-        known_genes_ids = {genes_qs.values_list("identifier", flat=True)}
+        known_genes_ids = set(genes_qs.values_list("identifier", flat=True))
        transcripts_qs = Transcript.objects.filter(annotation_consortium=annotation_consortium)
-        known_transcript_ids = {transcripts_qs.values_list("identifier", flat=True)}
+        known_transcript_ids = set(transcripts_qs.values_list("identifier", flat=True))

        gene_version_qs = GeneVersion.objects.filter(genome_build=genome_build,
-                                                     annotation_consortium=annotation_consortium)
+                                                     gene__annotation_consortium=annotation_consortium)
        known_gene_version_ids_by_accession = {f"{gene_id}.{version}": pk for (pk, gene_id, version) in gene_version_qs.values_list("pk", "gene_id", "version")}
        transcript_version_qs = TranscriptVersion.objects.filter(genome_build=genome_build,
-                                                                 annotation_consortium=annotation_consortium)
-        known_transcript_version_ids_by_accession = {f"{transcript_id}.{version}": pk for (pk, transcript_id, version) in transcript_version_qs.values_list("pk", "transcript_id", "version")}
+                                                                 transcript__annotation_consortium=annotation_consortium)
+        tv_values = transcript_version_qs.values_list("pk", "transcript_id", "version")
+        known_transcript_version_ids_by_accession = {f"{transcript_id}.{version}": pk
+                                                     for (pk, transcript_id, version) in tv_values}

        for data in merged_data:
            import_data = data["gene_annotation_import"]
-            logging.info("%s has %d transcripts", import_data, len(data["transcript_versions"]))
+            logging.info("%s has %d transcripts", import_data, len(data["transcript_version"]))
            import_source = GeneAnnotationImport.objects.create(annotation_consortium=annotation_consortium,
                                                                genome_build=genome_build,
                                                                filename=import_data["path"],
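
The known_genes_ids / known_transcript_ids change fixes a subtle bug: wrapping a queryset in a set literal produces a one-element set containing the queryset object itself, so later membership tests on identifiers never match, whereas set(...) evaluates the queryset and collects its elements. A rough illustration, with a plain list standing in for the queryset:

    # Stand-in for genes_qs.values_list("identifier", flat=True)
    identifiers = ["GENE1", "GENE2", "GENE3"]

    # The old form, {queryset}, built a set holding the queryset object,
    # so "GENE2" in known_genes_ids was always False.
    known_genes_ids = set(identifiers)   # {"GENE1", "GENE2", "GENE3"}
    print("GENE2" in known_genes_ids)    # True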
@@ -162,17 +171,19 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
            if new_gene_symbols:
                logging.info("Creating %d new gene symbols", len(new_gene_symbols))
                GeneSymbol.objects.bulk_create(new_gene_symbols, batch_size=self.BATCH_SIZE)
+                known_gene_symbols.update({gene_symbol.symbol for gene_symbol in new_gene_symbols})

            if new_genes:
                logging.info("Creating %d new genes", len(new_genes))
                Gene.objects.bulk_create(new_genes, batch_size=self.BATCH_SIZE)
+                # Update with newly inserted records - so that we have a PK to use below
+                known_genes_ids.update({gene.identifier for gene in new_genes})

            if new_gene_versions:
                logging.info("Creating %d new gene versions", len(new_gene_versions))
                GeneVersion.objects.bulk_create(new_gene_versions, batch_size=self.BATCH_SIZE)
-
-            # Update with newly inserted gene versions - so that we have a PK to use below
-            known_gene_version_ids_by_accession.update({f"{gv.gene_id}.{gv.version}" for gv in new_gene_versions})
+                # Update with newly inserted records - so that we have a PK to use below
+                known_gene_version_ids_by_accession.update({f"{gv.gene_id}.{gv.version}" for gv in new_gene_versions})

            # Could potentially be duplicate gene versions (diff transcript versions from diff GFFs w/same GeneVersion)
            if modified_gene_versions:
@@ -181,15 +192,14 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
                                              ["gene_symbol_id", "hgnc_id", "description", "biotype", "import_source"],
                                              batch_size=self.BATCH_SIZE)

-            new_transcripts = []
+            new_transcript_ids = set()
            new_transcript_versions = []
            modified_transcript_versions = []

            for transcript_accession, tv_data in data["transcript_version"].items():
                transcript_id, version = TranscriptVersion.get_transcript_id_and_version(transcript_accession)
                if transcript_id not in known_transcript_ids:
-                    new_transcripts.append(Transcript(identifier=transcript_id,
-                                                      annotation_consortium=annotation_consortium))
+                    new_transcript_ids.add(transcript_id)

                gene_version_id = known_gene_version_ids_by_accession[tv_data["gene_version"]]
                transcript_version = TranscriptVersion(transcript_id=transcript_id,
@@ -205,10 +215,15 @@ def _import_merged_data(self, genome_build: GenomeBuild, annotation_consortium,
                else:
                    new_transcript_versions.append(transcript_version)

-            if new_transcripts:
-                logging.info("Creating %d new transcripts", len(new_transcripts))
+            if new_transcript_ids:
+                logging.info("Creating %d new transcripts", len(new_transcript_ids))
+                new_transcripts = [Transcript(identifier=transcript_id, annotation_consortium=annotation_consortium)
+                                   for transcript_id in new_transcript_ids]
+
                Transcript.objects.bulk_create(new_transcripts, batch_size=self.BATCH_SIZE)
+                known_transcript_ids.update(new_transcript_ids)

+            # No need to update known after insert as there won't be duplicate transcript versions in the merged data
            if new_transcript_versions:
                logging.info("Creating %d new transcript versions", len(new_transcript_versions))
                TranscriptVersion.objects.bulk_create(new_transcript_versions, batch_size=self.BATCH_SIZE)
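
Taken together, the inserts follow one pattern: read the existing identifiers up front, bulk_create only the missing rows, then fold the new identifiers back into the "known" caches so later iterations over merged_data (e.g. a second GFF in the same run) do not try to insert them again. A schematic, framework-free sketch of that bookkeeping (names are illustrative, not the project's API):

    known_ids = {"NM_000001.1"}   # identifiers already in the database
    batches = [["NM_000001.1", "NM_000002.1"], ["NM_000002.1", "NM_000003.1"]]

    for batch in batches:
        new_ids = {i for i in batch if i not in known_ids}
        if new_ids:
            # build model instances for new_ids and bulk_create them here
            known_ids.update(new_ids)   # make them visible to the next batch

    print(sorted(known_ids))   # ['NM_000001.1', 'NM_000002.1', 'NM_000003.1']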
