diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index 9bbf3ba..9f90823 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -10,14 +10,9 @@ name: 'Dependency review' on: pull_request: - branches: [ "master" ] + branches: + - master # Triggers only on pull requests targeting the master branch -# If using a dependency submission action in this workflow this permission will need to be set to: -# -# permissions: -# contents: write -# -# https://docs.github.com/en/enterprise-cloud@latest/code-security/supply-chain-security/understanding-your-software-supply-chain/using-the-dependency-submission-api permissions: contents: read # Write permissions for pull-requests are required for using the `comment-summary-in-pr` option, comment out if you aren't using this option @@ -36,4 +31,4 @@ jobs: comment-summary-in-pr: always # fail-on-severity: moderate # deny-licenses: GPL-1.0-or-later, LGPL-2.0-or-later - # retry-on-snapshot-warnings: true + # retry-on-snapshot-warnings: true \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..89485b6 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,53 @@ +# This is a GitHub Actions workflow file for linting Python code using Tox and Flake8. 
+name: Lint + +on: + push: + branches: + - '**' # run on every push to any branch + pull_request: + branches: + - master # run on pull requests targeting the master branch + +permissions: + contents: write + +jobs: + lint: + runs-on: ubuntu-latest + steps: + # Step 1: Checkout the repository + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch the full history to allow branch checkout + + # Step 2: Determine the branch name + - name: Get branch name + id: vars + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "BRANCH_NAME=${{ github.head_ref }}" >> $GITHUB_ENV + else + echo "BRANCH_NAME=$(echo ${GITHUB_REF#refs/heads/})" >> $GITHUB_ENV + fi + + # Step 3: Install dependencies + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox + + # Step 4: Run Tox for Linting + - name: Run Tox Lint + run: tox -e flake8 + + # Step 5: Commit and push changes if Black reformats files + - name: Commit and push changes + run: | + git config --global user.name "GitHub Actions" + git config --global user.email "actions@github.com" + git checkout $BRANCH_NAME # Check out the branch + git add . 
+ git commit -m "Apply linting fixes" || echo "No changes to commit" + git push origin $BRANCH_NAME \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 452da05..ff35b58 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest] - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ['3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v2 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index fb523be..18eee5a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,6 +8,6 @@ sphinx: # Explicitly set the version of Python python: - version: 3.8 + version: 3.12 install: - requirements: docs/requirements.txt \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b1850f..f36ad4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,24 +1,40 @@ # Change Log ## TODO: ideas, issues and planed extensions or changes that are not yet implemented + * optimize add_qc_metrics for run after new samples have been added - should not recompute everything -* planned new feature: during import of long reads, (optionally) correct for short exon alignment issues. +* planned new feature: during import of long reads, (optionally) correct for short exon alignment issues. * separate new read import and classification of isoforms. 
+## [2.0.0] + +* support the analysis of Oxford Nanopore data +* new option of TSS identification according to the reference annotation +* new gene model characteristics, simplex coordinates and relative entropy +* new visualisation using triangle plot +* fix bugs in external gtf import to reconstruct transcriptome +* improve the coordination analysis of TSS and PAS +* support the import of SQANTI QC report and filtering based on it +* support the export of positive and negative TSS for ML +* support proteogenomic approaches at the interface of transcriptomics and proteomics +* improve code readability and filter tag construction + ## [0.3.5] + * fixed a bug in domain plots, which was introduced in 0.3.4 * fixed a bug in iter_genes/iter_transcripts with region='chr', and no positions specified * new option in plot_domains to depict noncoding transcripts, controlled with the coding_only parameter ## [0.3.4] + * fixing #8: AssertationError when unifying TSS/PAS between transcript -* improved domain plots: ORF start and end do not appear like exon exon boundaries. -* API change: separated ORF prediction from QC metrics calculation. +* improved domain plots: ORF start and end do not appear like exon exon boundaries. +* API change: separated ORF prediction from QC metrics calculation. * new feature: count number of upstream start codons in Gene.add_orfs() (called by default when adding QC metrics to transcriptome) -* new feature: calculate Fickett testcode and hexamer score for longest ORFs, to separate coding and noncoding genes. - +* new feature: calculate Fickett testcode and hexamer score for longest ORFs, to separate coding and noncoding genes. 
## [0.3.3] + * fixed bug in filter_ref_transcripts with no query * export gtf with long read transcripts as well uncovered as reference transcripts * fix warning in plot_diff_results @@ -28,11 +44,13 @@ * improved documentation: syntax highlighting, code style, additional explanations on filtering ## [0.3.2] + * restructured tutorials * new feature: add domains to differential splicing result tables. * new feature: min_coverage and max_coverage for iter_genes function. ## [0.3.1] + * new feature: add protein domains from 3 different sources and depict them with Gene.plot_domains() * new feature: restrict gene and transcript iterators on list of genes of interest * new feature: filter_transcripts function for genes @@ -43,39 +61,45 @@ * order of events is now according to gene strand: A upstream of B ## [0.3.0] + * new feature: find longest ORF and infer NMD of lr transcripts (and annotation) * new feature: allow for several TSS/PAS per intron chain and unify them across intron chains * changed default parameter of filter_query in run_isotools script to "FSM or not (INTERNAL_PRIMING or RTTS)" ## [0.2.11.1] -* bugfix: KeyError during transcriptome reconstruction in _add_chimeric. + +* bugfix: KeyError during transcriptome reconstruction in _add_chimeric. * bugfix: default colors in plot_diff_results. ## [0.2.11] + * added function to import samples from csv/gtf to import transcriptome reconstruction / quantification from other tools. * dropped requirement for gtf files to be tabix indexed. ## [0.2.10] + * fixed get_overlap - important for correct assignment of mono exonic genes to reference * added parameter to control for minimal mapping quality in add_sample_from_bam. This allows for filtering out ambiguous reads, which have mapping quality of 0 * fixed plot_diff_result (Key error due to incorrect parsing of group names) -* New function estimate_tpm_threshold, to estimate the minimal abundance level of observable transcripts, given a sequencing depth. 
-* New function coordination_test, to test coordination of splicing events within a gene. +* New function estimate_tpm_threshold, to estimate the minimal abundance level of observable transcripts, given a sequencing depth. +* New function coordination_test, to test coordination of splicing events within a gene. * Optional log or linear scale for the coverage axis in sashimi plots. ## [0.2.9] + * added DIE test * adjusted classification of novel exonic TSS/PAS to ISM -* improved assignment of reference genes in case of equal number of matching splice sites to several reference genes. +* improved assignment of reference genes in case of equal number of matching splice sites to several reference genes. * added parameter to control for minimal exonic overlap to reference genes in add_sample_from_bam. * changed computation of direct repeats. Added wobble and max_mm parameters. -* exposed parameters to end user in the add_qc_metrics function. +* exposed parameters to end user in the add_qc_metrics function. * added options for additional fields in gtf output * improved options for graphical output with the command line script * fixed plot_bar default color scheme ## [0.2.8] + * fix: version information lost when pickeling reference. 
* fix missing gene name * added pt_size parameter to plot_embedding and plot_diff_results function @@ -84,11 +108,13 @@ ## [0.2.7] + * added command line script run_isotools.py -* added test data for unit tests +* added test data for unit tests ## [0.2.6] + * Added unit tests * Fixed bug in novel splicing subcategory assignment * new feature: rarefaction analysis @@ -98,41 +124,48 @@ * added optional progress bar to iter_genes/transcripts ## [0.2.5] + * New feature: distinguish noncanonical and canonical novel splice sites for direct repeat hist * New feature: option to drop partially aligned reads with the min_align_fraction parameter in add_sample_from_bam ## [0.2.4] + * New feature: added option to save read names during bam import * new feature: gzip compressed gtf output ## [0.2.3] + * Changed assignment of transcripts to genes if no splice sites match. * Fix: more flexible import of reference files, gene name not required (but id is), introducing "infer_genes" from exon entries of gtf files. * New function: Transcriptome.remove_filter(filter=[tags]) ## [0.2.2] + * Fix: export to gtf with filter features ## [0.2.1] + * Fix: import reference from gtf file * New feature: Import multiple samples from single bam tagged by barcode (e.g. from single cell data) * Fix: issue with zero base exons after shifting fuzzy junctions - ## [0.2.0] + * restructure to meet PyPI recommendations * New feature: isoseq.altsplice_test accepts more than 2 groups, and computes ML parameters for all groups ## [0.1.5] + * New feature: restrict tests on provided splice_types * New feature: provide position to find given alternative splicing events ## [0.1.4] + * Fix: Issue with noncanonical splicing detection introduced in 0.1.3 * Fix: crash with secondary alignments in bam files during import. 
* New feature: Report and skip if alignment outside chromosome (uLTRA issue) * Fix: import of chimeric reads (secondary alignments have no SA tag) -* Fix: Transcripts per sample in sample table: During import count only used transcripts, do not count chimeric transcripts twice. +* Fix: Transcripts per sample in sample table: During import count only used transcripts, do not count chimeric transcripts twice. * Change: sample_table reports chimeric_reads and nonchimeric_reads (instead of total_reads) * Change: import of long read bam is more verbose in info mode * Fix: Bug: import of chained chimeric alignments overwrites read coverage when merging to existing transcript @@ -140,13 +173,14 @@ * Change: refactored add_biases to add_qc_metrics * fix: property of transcripts included {sample_name:0} * save the TSS and PAS positions -* New: use_satag parameter for add_sample_from_bam +* New: use_satag parameter for add_sample_from_bam * Change: use median TSS/PAS (of all reads with same splice pattern) as transcript start/end (e.g. exons[0][0]/exons[-1][1]) * Fix: Novel exon skipping annotation now finds all exonic regions that are skipped. * change: Default filter of FRAGMENTS now only tags reads that do not use a reference TSS or PAS + ## [0.1.3] -* Fix: improved performance of noncanonical splicing detection by avoiding redundant lookups. +* Fix: improved performance of noncanonical splicing detection by avoiding redundant lookups. ## [0.1.2] - 2020-05-03 @@ -157,7 +191,6 @@ * New: Do not distinguish intronic/exonic novel splice sites. Report distance to shortest splice site of same type. * Fix: Sashimi plots ignored mono exons - ## [0.1.1] - 2020-04-12 * Fix: fixed bug in TSS/PAS events affecting start/end positions and known flag. 
@@ -170,23 +203,23 @@ * moved examples in documentation ## [0.0.2] - 2020-03-22 + * Change: refactored SpliceGraph to SegmentGraph to better comply with common terms in literature -* New: added a basic implementation of an actual SpliceGraph (as commonly defined in literature) +* New: added a basic implementation of an actual SpliceGraph (as commonly defined in literature) * based on sorted dict * not used so far, but maybe useful in importing the long read bam files since it can be extended easily -* New: added decorators "experimental" and "deprecated" to mark unsafe functions +* New: added decorators "experimental" and "deprecated" to mark unsafe functions * Change: in differential splicing changed the alternative fraction, to match the common PSI (% spliced in) definition * Change: narrowed definition of mutually exclusive exons: the alternatives now need to to feature exactly one ME exon and rejoin at node C * Change: for ME exons now the beginning of node C is returned as "end" of the splice bubble -* New: differential splicing result contains "novel", indicating that the the alternative is in the annotation +* New: differential splicing result contains "novel", indicating that the the alternative is in the annotation * New: added alternative TSS/alternative PAS to the differential splicing test * Change: removed obsolete weights from splice graph and added strand * Change: unified parameters and column names of results of Transcriptome.find_splice_bubbles() and Transcriptome.altsplice_test() -* Fix: add_short_read_coverage broken if short reads are already there. - +* Fix: add_short_read_coverage broken if short reads are already there. 
## [0.0.1] - 2020-02-25 + * first shared version * New: added option to export alternative splicing events for MISO and rMATS * New: added change log - diff --git a/README.md b/README.md index 78a3bc7..e7ac775 100644 --- a/README.md +++ b/README.md @@ -62,5 +62,7 @@ isoseq.save('../tests/data/example_1_isotools.pkl') ## Citation and feedback: * If you run into any issues, please use the [github issues report feature](https://github.com/HerwigLab/IsoTools2/issues). -* For general feedback, please write me an email to [lienhard@molgen.mpg.de](mailto:lienhard@molgen.mpg.de). -* If you use IsoTools in your publication, please cite the following [paper](https://doi.org/10.1093/bioinformatics/btad364): Lienhard et al, Bioinformatics, 2023: IsoTools: a flexible workflow for long-read transcriptome sequencing analysis +* For general feedback, please write us an email to [yalan_bi@molgen.mpg.de](mailto:yalan_bi@molgen.mpg.de) and [herwig@molgen.mpg.de](mailto:herwig@molgen.mpg.de). +* If you use IsoTools in your publication, please cite the following paper in addition to this repository: + * Lienhard, Matthias et al. “IsoTools: a flexible workflow for long-read transcriptome sequencing analysis.” Bioinformatics (Oxford, England) vol. 39,6 (2023): btad364. [doi:10.1093/bioinformatics/btad364](https://doi.org/10.1093/bioinformatics/btad364) + * Bi, Yalan et al. “IsoTools 2.0: Software for Comprehensive Analysis of Long-read Transcriptome Sequencing Data.” Journal of molecular biology, 169049. 26 Feb. 
2025, [doi:10.1016/j.jmb.2025.169049](https://doi.org/10.1016/j.jmb.2025.169049) diff --git a/VERSION.txt b/VERSION.txt index 7bfef18..227cea2 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.3.5_rc11 +2.0.0 diff --git a/docs/conf.py b/docs/conf.py index c8c319c..d06c8d5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,20 +12,23 @@ # import os import sys -with open(os.path.join('..', 'VERSION.txt'), 'r') as versionfile: # version string is the first line of this file + +with open( + os.path.join("..", "VERSION.txt"), "r" +) as versionfile: # version string is the first line of this file __version__ = versionfile.read().strip() # Location of source files -sys.path.insert(0, os.path.abspath('../src')) +sys.path.insert(0, os.path.abspath("../src")) # -- Project information ----------------------------------------------------- -project = 'isotools' -copyright = '2021, Matthias Lienhard' -author = 'Matthias Lienhard' +project = "isotools" +copyright = "2021, Matthias Lienhard" +author = "Matthias Lienhard" # The short X.Y version -version = '.'.join(__version__.split('.')[:2]) +version = ".".join(__version__.split(".")[:2]) # The full version, including alpha/beta/rc tags release = __version__ @@ -37,29 +40,29 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinxarg.ext', - 'sphinx.ext.viewcode', - 'sphinx.ext.todo', + "sphinx.ext.autodoc", + "sphinxarg.ext", + "sphinx.ext.viewcode", + "sphinx.ext.todo", "sphinx_rtd_theme", - 'myst_parser', - 'nbsphinx' + "myst_parser", + "nbsphinx", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. 
-language = 'en' +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -67,12 +70,12 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # -- Extension configuration ------------------------------------------------- @@ -83,4 +86,4 @@ # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" diff --git a/docs/notebooks/01_prepare_data.ipynb b/docs/notebooks/01_prepare_data.ipynb index 4ea414a..e6702f0 100644 --- a/docs/notebooks/01_prepare_data.ipynb +++ b/docs/notebooks/01_prepare_data.ipynb @@ -420,7 +420,7 @@ "source": [ "## Demonstration Data\n", "\n", - "To create an demonstration data set, we aligned the ENCODE fastq files with minimap2, and sub-selected reads mapping to chromosome 8 only. All resulting files (~270 Mb) [can be downloaded here](https://nc.molgen.mpg.de/cloud/index.php/s/zYe7g6qnyxGDxRd)." + "To create an demonstration data set, we aligned the ENCODE fastq files with minimap2, and sub-selected reads mapping to chromosome 8 only. All resulting files (~270 Mb) [can be downloaded here](https://nc.molgen.mpg.de/cloud/index.php/s/Mf2zMePGBzFWFk8)." 
] }, { diff --git a/setup.cfg b/setup.cfg index 277d576..764cd13 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,15 +1,15 @@ [metadata] name = isotools version = file: VERSION.txt -author = Matthias Lienhard -author_email = lienhard@molgen.mpg.de +author = Matthias Lienhard, Yalan Bi +author_email = lienhard@molgen.mpg.de, yalan_bi@molgen.mpg.de description = Framework for the analysis of long read transcriptome sequencing data long_description = file: README.md long_description_content_type = text/markdown license_files = LICENSE.txt -url = https://github.com/MatthiasLienhard/isotools +url = https://github.com/HerwigLab/IsoTools2 project_urls = - Bug Tracker = https://github.com/MatthiasLienhard/isotools/issues + Bug Tracker = https://github.com/HerwigLab/IsoTools2/issues classifiers = Programming Language :: Python :: 3 License :: OSI Approved :: MIT License @@ -19,7 +19,7 @@ classifiers = package_dir = = src packages = find: -python_requires = >=3.6 +python_requires = >=3.10 install_requires = numpy pandas @@ -54,4 +54,5 @@ testing= tox>=3.24 [flake8] -max-line-length=160 +max-line-length = 176 +extend-ignore = E203,E701 diff --git a/src/isotools/__init__.py b/src/isotools/__init__.py index 4e91883..af8154d 100644 --- a/src/isotools/__init__.py +++ b/src/isotools/__init__.py @@ -22,19 +22,32 @@ Controlled vocabulary for filtering by novel alternative splicing. 
""" - try: from importlib.metadata import distribution except ModuleNotFoundError: from importlib_metadata import distribution # py3.7 -__version__ = distribution('isotools').version +__version__ = distribution("isotools").version from .gene import Gene from .transcriptome import Transcriptome from .splice_graph import SegmentGraph, SegGraphNode from ._transcriptome_stats import estimate_tpm_threshold -from ._transcriptome_filter import DEFAULT_GENE_FILTER, DEFAULT_TRANSCRIPT_FILTER, DEFAULT_REF_TRANSCRIPT_FILTER, ANNOTATION_VOCABULARY - - -__all__ = ['Transcriptome', 'Gene', 'SegmentGraph', 'SegGraphNode', 'estimate_tpm_threshold', - 'DEFAULT_GENE_FILTER', 'DEFAULT_TRANSCRIPT_FILTER', 'DEFAULT_REF_TRANSCRIPT_FILTER', 'ANNOTATION_VOCABULARY'] +from ._transcriptome_filter import ( + DEFAULT_GENE_FILTER, + DEFAULT_TRANSCRIPT_FILTER, + DEFAULT_REF_TRANSCRIPT_FILTER, + ANNOTATION_VOCABULARY, +) + + +__all__ = [ + "Transcriptome", + "Gene", + "SegmentGraph", + "SegGraphNode", + "estimate_tpm_threshold", + "DEFAULT_GENE_FILTER", + "DEFAULT_TRANSCRIPT_FILTER", + "DEFAULT_REF_TRANSCRIPT_FILTER", + "ANNOTATION_VOCABULARY", +] diff --git a/src/isotools/_gene_plots.py b/src/isotools/_gene_plots.py index 44b5482..28f2057 100644 --- a/src/isotools/_gene_plots.py +++ b/src/isotools/_gene_plots.py @@ -1,4 +1,3 @@ - import collections.abc import matplotlib.colors as plt_col import matplotlib.patches as patches @@ -8,7 +7,8 @@ from math import log10 from ._utils import has_overlap, pairwise import logging -logger = logging.getLogger('isotools') + +logger = logging.getLogger("isotools") def _label_overlap(pos1, pos2, width, height): @@ -17,22 +17,40 @@ def _label_overlap(pos1, pos2, width, height): return False -DEFAULT_JPARAMS = [{'color': 'lightgrey', 'lwd': 1, 'draw_label': False}, # low coverage junctions - {'color': 'green', 'lwd': 1, 'draw_label': True}, # high coverage junctions - {'color': 'purple', 'lwd': 2, 'draw_label': True}] # junctions of interest -DEFAULT_PARAMS = 
dict(min_cov_th=.001, high_cov_th=.05, text_width=.02, arc_type='both', text_height=1, exon_color='green') -DOMAIN_COLS = {"Family": "red", "Domain": "green", "Repeat": "orange", "Coiled-coil": "blue", "Motif": "grey", "Disordered": "pink"} +DEFAULT_JPARAMS = [ + {"color": "lightgrey", "lwd": 1, "draw_label": False}, # low coverage junctions + {"color": "green", "lwd": 1, "draw_label": True}, # high coverage junctions + {"color": "purple", "lwd": 2, "draw_label": True}, +] # junctions of interest +DEFAULT_PARAMS = dict( + min_cov_th=0.001, + high_cov_th=0.05, + text_width=0.02, + arc_type="both", + text_height=1, + exon_color="green", +) +DOMAIN_COLS = { + "Family": "red", + "Domain": "green", + "Repeat": "orange", + "Coiled-coil": "blue", + "Motif": "grey", + "Disordered": "pink", +} def extend_params(params): if params is None: params = dict() - params.setdefault('jparams', [{}, {}, {}]) + params.setdefault("jparams", [{}, {}, {}]) # jparams=[params.pop(k,jparams[i]) for i,k in enumerate(['low_cov_junctions','high_cov_junctions','interest_junctions'])] - for i, k1 in enumerate(['low_cov_junctions', 'high_cov_junctions', 'interest_junctions']): - params['jparams'][i] = params.pop(k1, params['jparams'][i]) + for i, k1 in enumerate( + ["low_cov_junctions", "high_cov_junctions", "interest_junctions"] + ): + params["jparams"][i] = params.pop(k1, params["jparams"][i]) for k2, v in DEFAULT_JPARAMS[i].items(): - params['jparams'][i].setdefault(k2, v) + params["jparams"][i].setdefault(k2, v) for k, v in DEFAULT_PARAMS.items(): params.setdefault(k, v) return params @@ -49,16 +67,26 @@ def get_index(samples, names): sample_idx = [idx[sample] for sample in samples] except KeyError: notfound = [sample for sample in samples if sample not in idx] - logger.error('did not find the following samples: %s', ','.join(notfound)) + logger.error("did not find the following samples: %s", ",".join(notfound)) raise return sample_idx + # sashimi plots -def sashimi_figure(self, samples=None, 
short_read_samples=None, draw_gene_track=True, draw_other_genes=True, - long_read_params=None, short_read_params=None, junctions_of_interest=None, x_range=None): - '''Arranges multiple Sashimi plots of the gene. +def sashimi_figure( + self, + samples=None, + short_read_samples=None, + draw_gene_track=True, + draw_other_genes=True, + long_read_params=None, + short_read_params=None, + junctions_of_interest=None, + x_range=None, +): + """Arranges multiple Sashimi plots of the gene. The Sashimi figure consist of a reference gene track, long read sashimi plots for one or more samples or groups of samples, and optionally short read sashimi plots for one or more samples or groups of samples. @@ -73,7 +101,7 @@ def sashimi_figure(self, samples=None, short_read_samples=None, draw_gene_track= See isotools._gene_plots.DEFAULT_PARAMS and isotools._gene_plots.DEFAULT_JPARAMS :param junctions_of_interest: List of int pairs to define junctions of interest (which are highlighed in the plots) :param x_range: Genomic positions to specify the x range of the plot. 
- :return: Tuple with figure and axses''' + :return: Tuple with figure and axses""" draw_gene_track = bool(draw_gene_track) @@ -82,7 +110,7 @@ def sashimi_figure(self, samples=None, short_read_samples=None, draw_gene_track= if short_read_samples is None: short_read_samples = {} if not samples and not short_read_samples: - samples = {'all': None} + samples = {"all": None} if long_read_params is None: long_read_params = {} if short_read_params is None: @@ -95,19 +123,46 @@ def sashimi_figure(self, samples=None, short_read_samples=None, draw_gene_track= self.gene_track(ax=axes[0], x_range=x_range, draw_other_genes=draw_other_genes) for i, (sname, sidx) in enumerate(samples.items()): - self.sashimi_plot(sidx, sname, axes[i + draw_gene_track], junctions_of_interest, x_range=x_range, **long_read_params) + self.sashimi_plot( + sidx, + sname, + axes[i + draw_gene_track], + junctions_of_interest, + x_range=x_range, + **long_read_params, + ) for i, (sname, sidx) in enumerate(short_read_samples.items()): - self.sashimi_plot_short_reads(sidx, sname, axes[i + len(samples) + draw_gene_track], junctions_of_interest, x_range=x_range, **long_read_params) + self.sashimi_plot_short_reads( + sidx, + sname, + axes[i + len(samples) + draw_gene_track], + junctions_of_interest, + x_range=x_range, + **long_read_params, + ) return f, axes -def sashimi_plot_short_reads(self, samples=None, title='short read coverage', ax=None, junctions_of_interest=None, x_range=None, - y_range=None, log_y=True, - jparams=None, min_cov_th=.001, high_cov_th=.05, text_width=.02, arc_type='both', text_height=1, - exon_color='green'): - '''Draws short read Sashimi plot of the gene. 
+def sashimi_plot_short_reads( + self, + samples=None, + title="short read coverage", + ax=None, + junctions_of_interest=None, + x_range=None, + y_range=None, + log_y=True, + jparams=None, + min_cov_th=0.001, + high_cov_th=0.05, + text_width=0.02, + arc_type="both", + text_height=1, + exon_color="green", +): + """Draws short read Sashimi plot of the gene. The Sashimi plot depicts the genomic coverage from short read sequencing as blocks, and junction coverage as arcs. @@ -128,11 +183,16 @@ def sashimi_plot_short_reads(self, samples=None, title='short read coverage', ax :param text_width: Control the horizontal space that gets reserved for labels on the arcs. This affects the height of the arcs. :param arc_type: Label the junction arcs with the "coverage" (e.g. number of supporting reads), "fraction" (e.g. fraction of supporting reads in %), or "both". - :param text_height: Control the vertical space that gets reserved for labels on the arcs. This affects the height of the arcs.''' + :param text_height: Control the vertical space that gets reserved for labels on the arcs. This affects the height of the arcs. 
+ """ if samples is None: - samples = list(self._transcriptome.infos['short_reads']['name']) # all samples grouped # pylint: disable=W0212 - sidx = get_index(samples, self._transcriptome.infos['short_reads']['name']) # pylint: disable=W0212 + samples = list( + self._transcriptome.infos["short_reads"]["name"] + ) # all samples grouped # pylint: disable=W0212 + sidx = get_index( + samples, self._transcriptome.infos["short_reads"]["name"] + ) # pylint: disable=W0212 if x_range is None: x_range = (self.start - 100, self.end + 100) @@ -176,8 +236,13 @@ def sashimi_plot_short_reads(self, samples=None, title='short read coverage', ax center = (x1 + x2) / 2 width = x2 - x1 bow_height = text_height - if jparams[priority]['draw_label']: - while any(_label_overlap((center, max(y1, y2) + bow_height), tp, text_width, text_height) for tp in textpositions): + if jparams[priority]["draw_label"]: + while any( + _label_overlap( + (center, max(y1, y2) + bow_height), tp, text_width, text_height + ) + for tp in textpositions + ): bow_height += text_height textpositions.append((center, max(y1, y2) + bow_height)) if y1 < y2: @@ -186,15 +251,39 @@ def sashimi_plot_short_reads(self, samples=None, title='short read coverage', ax bow_height = (bow_height, bow_height + y1 - y2) else: bow_height = (bow_height, bow_height) - bow1 = patches.Arc((center, y1), width=width, height=bow_height[0] * 2, theta1=90, theta2=180, - linewidth=jparams[priority]['lwd'], edgecolor=jparams[priority]['color'], zorder=priority) - bow2 = patches.Arc((center, y2), width=width, height=bow_height[1] * 2, theta1=0, theta2=90, - linewidth=jparams[priority]['lwd'], edgecolor=jparams[priority]['color'], zorder=priority) + bow1 = patches.Arc( + (center, y1), + width=width, + height=bow_height[0] * 2, + theta1=90, + theta2=180, + linewidth=jparams[priority]["lwd"], + edgecolor=jparams[priority]["color"], + zorder=priority, + ) + bow2 = patches.Arc( + (center, y2), + width=width, + height=bow_height[1] * 2, + theta1=0, + 
theta2=90, + linewidth=jparams[priority]["lwd"], + edgecolor=jparams[priority]["color"], + zorder=priority, + ) ax.add_patch(bow1) ax.add_patch(bow2) - if jparams[priority]['draw_label']: - _ = ax.text(center, max(y1, y2) + min(bow_height) + text_height / 3, w, horizontalalignment='center', verticalalignment='bottom', - bbox=dict(boxstyle='round', facecolor='wheat', edgecolor=None, alpha=0.5)).set_clip_on(True) + if jparams[priority]["draw_label"]: + _ = ax.text( + center, + max(y1, y2) + min(bow_height) + text_height / 3, + w, + horizontalalignment="center", + verticalalignment="bottom", + bbox=dict( + boxstyle="round", facecolor="wheat", edgecolor=None, alpha=0.5 + ), + ).set_clip_on(True) # bbox_list.append(txt.get_tightbbox(renderer = fig.canvas.renderer)) ax.set_xlim(*x_range) @@ -210,19 +299,32 @@ def sashimi_plot_short_reads(self, samples=None, title='short read coverage', ax ax.set_yticklabels([1, 10, 100, 1000]) # ax.ticklabel_format(axis='x', style='sci',scilimits=(6,6)) ax.set_title(title) - ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos=None: f'{x:,.0f}')) + ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos=None: f"{x:,.0f}")) return ax -def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, junctions_of_interest=None, x_range=None, select_transcripts=None, - y_range=None, log_y=True, - jparams=None, exon_color='green', min_cov_th=.001, high_cov_th=.05, text_width=1, - arc_type='both', text_height=1): - '''Draws long read Sashimi plot of the gene. +def sashimi_plot( + self, + samples=None, + title="Long read sashimi plot", + ax=None, + junctions_of_interest=None, + x_range=None, + select_transcripts=None, + y_range=None, + log_y=True, + jparams=None, + exon_color="green", + min_cov_th=0.001, + high_cov_th=0.05, + text_width=1, + arc_type="both", + text_height=1, +): + """Draws long read Sashimi plot of the gene. 
The Sashimi plot depicts the genomic long read sequencing coverage of one or more samples as blocks, and junction coverage as arcs. - :param samples: Names of the samples to be depicted (as a list). :param title: Specify the title of the axis. :param ax: Specify the axis. @@ -242,10 +344,10 @@ def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, ju :param min_cov_th: Coverage threshold for a junction to be considdered at all. :param text_width: Scaling factor for the horizontal space that gets reserved for labels on the arcs. This affects the height of the arcs. - :param arc_type: Label the junction arcs with the "coverage" (e.g. number of supporting reads), + :param arc_type: Label the junction arcs with the "coverage" (e.g. number of supporting reads), "fraction" (e.g. fraction of supporting reads in %), or "both". :param text_height: Scaling factor for the vertical space that gets reserved for labels on the arcs. - This affects the height of the arcs.''' + This affects the height of the arcs.""" sg = self.segment_graph if jparams is None: @@ -265,16 +367,19 @@ def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, ju try: _ = iter(select_transcripts) # maybe only one transcript provided? 
except TypeError: - select_transcripts = select_transcripts, + select_transcripts = (select_transcripts,) mask = np.ones(node_matrix.shape[0], np.bool) mask[select_transcripts] = False node_matrix[mask, :] = 0 - boxes = [(node[0], node[1], self.coverage[np.ix_(sidx, node_matrix[:, i])].sum()) for i, node in enumerate(sg)] + boxes = [ + (node[0], node[1], self.coverage[np.ix_(sidx, node_matrix[:, i])].sum()) + for i, node in enumerate(sg) + ] if log_y: - boxes = [(s, e, log10(c) if c > 1 else c/10) for s, e, c in boxes] + boxes = [(s, e, log10(c) if c > 1 else c / 10) for s, e, c in boxes] max_height = max(1, max(h for s, e, h in boxes if has_overlap(x_range, (s, e)))) - text_height = (max_height/10)*text_height - text_width = (x_range[1] - x_range[0]) * .02 * text_width + text_height = (max_height / 10) * text_height + text_width = (x_range[1] - x_range[0]) * 0.02 * text_width total_weight = self.coverage[sidx, :].sum() if high_cov_th < 1: @@ -290,10 +395,15 @@ def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, ju continue if select_transcripts is not None and transcript not in select_transcripts: continue - transcript_junction_coverage = self.coverage[np.ix_(sidx, [transcript])].sum() + transcript_junction_coverage = self.coverage[ + np.ix_(sidx, [transcript]) + ].sum() if transcript_junction_coverage: - weights[next_i] = weights.get(next_i, 0)+transcript_junction_coverage - arcs_new = [(ee, boxes[i][2], sg[next_i][0], boxes[next_i][2], w) for next_i, w in weights.items()] + weights[next_i] = weights.get(next_i, 0) + transcript_junction_coverage + arcs_new = [ + (ee, boxes[i][2], sg[next_i][0], boxes[next_i][2], w) + for next_i, w in weights.items() + ] if arcs_new: arcs.extend(arcs_new) if ax is None: @@ -301,7 +411,15 @@ def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, ju for st, end, h in boxes: if h > 0 & has_overlap(x_range, (st, end)): - rect = patches.Rectangle((st, 0), (end - st), h, linewidth=1, 
edgecolor=exon_color, facecolor=exon_color, zorder=5) + rect = patches.Rectangle( + (st, 0), + (end - st), + h, + linewidth=1, + edgecolor=exon_color, + facecolor=exon_color, + zorder=5, + ) ax.add_patch(rect) textpositions = [] for x1, y1, x2, y2, w in arcs: @@ -316,18 +434,23 @@ def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, ju else: priority = 1 text_x = (x1 + x2) / 2 - textalign = 'center' + textalign = "center" if text_x > x_range[1]: text_x = x_range[1] - textalign = 'right' + textalign = "right" elif text_x < x_range[0]: text_x = x_range[0] - textalign = 'left' + textalign = "left" width = x2 - x1 bow_height = text_height - if jparams[priority]['draw_label']: - while any(_label_overlap((text_x, max(y1, y2) + bow_height), tp, text_width, text_height) for tp in textpositions): + if jparams[priority]["draw_label"]: + while any( + _label_overlap( + (text_x, max(y1, y2) + bow_height), tp, text_width, text_height + ) + for tp in textpositions + ): bow_height += text_height textpositions.append((text_x, max(y1, y2) + bow_height)) if y1 < y2: @@ -336,46 +459,84 @@ def sashimi_plot(self, samples=None, title='Long read sashimi plot', ax=None, ju bow_height = (bow_height, bow_height + y1 - y2) else: bow_height = (bow_height, bow_height) - bow1 = patches.Arc(((x1 + x2) / 2, y1), width=width, height=bow_height[0] * 2, theta1=90, theta2=180, - linewidth=jparams[priority]['lwd'], edgecolor=jparams[priority]['color'], zorder=priority) - bow2 = patches.Arc(((x1 + x2) / 2, y2), width=width, height=bow_height[1] * 2, theta1=0, theta2=90, - linewidth=jparams[priority]['lwd'], edgecolor=jparams[priority]['color'], zorder=priority) + bow1 = patches.Arc( + ((x1 + x2) / 2, y1), + width=width, + height=bow_height[0] * 2, + theta1=90, + theta2=180, + linewidth=jparams[priority]["lwd"], + edgecolor=jparams[priority]["color"], + zorder=priority, + ) + bow2 = patches.Arc( + ((x1 + x2) / 2, y2), + width=width, + height=bow_height[1] * 2, + theta1=0, + 
theta2=90, + linewidth=jparams[priority]["lwd"], + edgecolor=jparams[priority]["color"], + zorder=priority, + ) ax.add_patch(bow1) ax.add_patch(bow2) - if jparams[priority]['draw_label']: - if arc_type == 'coverage': + if jparams[priority]["draw_label"]: + if arc_type == "coverage": lab = str(w) else: # fraction - lab = f'{w/total_weight:.1%}' - if arc_type == 'both': - lab = str(w) + ' / ' + lab - _ = ax.text(text_x, max(y1, y2) + min(bow_height) + text_height / 3, lab, - horizontalalignment=textalign, verticalalignment='bottom', zorder=10 + priority, - bbox=dict(boxstyle='round', facecolor='wheat', edgecolor=None, alpha=0.5)).set_clip_on(True) + lab = f"{w/total_weight:.1%}" + if arc_type == "both": + lab = str(w) + " / " + lab + _ = ax.text( + text_x, + max(y1, y2) + min(bow_height) + text_height / 3, + lab, + horizontalalignment=textalign, + verticalalignment="bottom", + zorder=10 + priority, + bbox=dict( + boxstyle="round", facecolor="wheat", edgecolor=None, alpha=0.5 + ), + ).set_clip_on(True) # bbox_list.append(txt.get_tightbbox(renderer = fig.canvas.renderer)) if y_range: ax.set_ylim(*y_range) elif textpositions: ax.set_ylim(-text_height, max(tp[1] for tp in textpositions) + 2 * text_height) else: - ax.set_ylim(-text_height, max_height+text_height) + ax.set_ylim(-text_height, max_height + text_height) ax.set_xlim(*x_range) ax.set(frame_on=False) if log_y: ax.set_yticks([0, 1, 2, 3]) ax.set_yticklabels([1, 10, 100, 1000]) - ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos=None: f'{x:,.0f}')) + ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos=None: f"{x:,.0f}")) # ax.ticklabel_format(axis='x', style='sci',scilimits=(6,6)) # ax.set_xscale(1e-6, 'linear') ax.set_title(title) -def gene_track(self, ax=None, title=None, reference=True, select_transcripts=None, label_exon_numbers=True, - label_transcripts=True, label_fontsize=10, colorbySqanti=True, color='blue', x_range=None, draw_other_genes=False, - query=None, min_coverage=None, 
max_coverage=None): - '''Draws a gene track of the gene. +def gene_track( + self, + ax=None, + title=None, + reference=True, + select_transcripts=None, + label_exon_numbers=True, + label_transcripts=True, + label_fontsize=10, + colorbySqanti=True, + color="blue", + x_range=None, + draw_other_genes=False, + query=None, + min_coverage=None, + max_coverage=None, +): + """Draws a gene track of the gene. The gene track depicts the exon structure of a gene, like in a genome browser. Exons are depicted as boxes, and junctions are lines. For coding regions, the height of the boxes is increased. @@ -397,7 +558,8 @@ def gene_track(self, ax=None, title=None, reference=True, select_transcripts=Non You can also provide a list of gene names/ids, to specify which other genes should be included. :param query: Filter query, which is passed to Gene.filter_transcripts or Gene.filter_ref_transcripts :param min_coverage: Minimum coverage for the transcript to be depicted. Ignored in case of reference=True. - :param max_coverage: Maximum coverage for the transcript to be depicted. Ignored in case of reference=True.''' + :param max_coverage: Maximum coverage for the transcript to be depicted. Ignored in case of reference=True. 
+ """ if select_transcripts is None: select_transcripts = {} @@ -409,16 +571,20 @@ def gene_track(self, ax=None, title=None, reference=True, select_transcripts=Non except TypeError: select_transcripts = {self.name: [select_transcripts]} - contrast = 'white' if np.mean(plt_col.to_rgb(color)) < .5 else 'black' + contrast = "white" if np.mean(plt_col.to_rgb(color)) < 0.5 else "black" # there is no Sqanti classification for reference transcripts if reference: colorbySqanti = False if colorbySqanti: - sqanti_palette = {0:{'tag':'FSM', 'color':'#6BAED6'}, 1:{'tag':'ISM', 'color':'#FC8D59'}, - 2:{'tag':'NIC', 'color':'#78C679'}, 3:{'tag':'NNC', 'color':'#EE6A50'}, - 4:{'tag':'NOVEL', 'color':'palevioletred'}} + sqanti_palette = { + 0: {"tag": "FSM", "color": "#6BAED6"}, + 1: {"tag": "ISM", "color": "#FC8D59"}, + 2: {"tag": "NIC", "color": "#78C679"}, + 3: {"tag": "NNC", "color": "#EE6A50"}, + 4: {"tag": "NOVEL", "color": "palevioletred"}, + } if ax is None: _, ax = plt.subplots(1) @@ -429,7 +595,9 @@ def gene_track(self, ax=None, title=None, reference=True, select_transcripts=Non if draw_other_genes: if isinstance(draw_other_genes, list): - ol_genes = {self._transcriptome[gene] for gene in draw_other_genes}.add(self) + ol_genes = {self._transcriptome[gene] for gene in draw_other_genes}.add( + self + ) else: ol_genes = self._transcriptome.data[self.chrom].overlap(*x_range) else: @@ -437,24 +605,59 @@ def gene_track(self, ax=None, title=None, reference=True, select_transcripts=Non transcript_list = [] for gene in ol_genes: - select_tr = gene.filter_ref_transcripts(query) if reference else gene.filter_transcripts(query, min_coverage, max_coverage) + select_tr = ( + gene.filter_ref_transcripts(query) + if reference + else gene.filter_transcripts(query, min_coverage, max_coverage) + ) if select_transcripts.get(gene.name): - select_tr = [transcript_id for transcript_id in select_tr if transcript_id in select_transcripts.get(gene.name)] + select_tr = [ + transcript_id + for 
transcript_id in select_tr + if transcript_id in select_transcripts.get(gene.name) + ] if reference: # select transcripts and sort by start - transcript_list.extend([(gene, transcript_nr, transcript) for transcript_nr, transcript in enumerate(gene.ref_transcripts) if transcript_nr in select_tr]) + transcript_list.extend( + [ + (gene, transcript_nr, transcript) + for transcript_nr, transcript in enumerate(gene.ref_transcripts) + if transcript_nr in select_tr + ] + ) else: - transcript_list.extend([(gene, transcript_number, transcript) for transcript_number, transcript in enumerate(gene.transcripts) if transcript_number in select_tr]) - transcript_list.sort(key=lambda x: x[2]['exons'][0][0]) # sort by start position + transcript_list.extend( + [ + (gene, transcript_number, transcript) + for transcript_number, transcript in enumerate(gene.transcripts) + if transcript_number in select_tr + ] + ) + transcript_list.sort(key=lambda x: x[2]["exons"][0][0]) # sort by start position for gene, transcript_number, transcript in transcript_list: - transcript_start, transcript_end = transcript['exons'][0][0], transcript['exons'][-1][1] - if (transcript_end < x_range[0] or transcript_start > x_range[1]): # transcript does not overlap x_range + transcript_start, transcript_end = ( + transcript["exons"][0][0], + transcript["exons"][-1][1], + ) + if ( + transcript_end < x_range[0] or transcript_start > x_range[1] + ): # transcript does not overlap x_range continue - transcript_id = '> ' if gene.strand == '+' else '< ' # indicate the strand like in ensembl browser - transcript_id += transcript['transcript_name'] if 'transcript_name' in transcript else f'{gene.name}_{transcript_number}' + transcript_id = ( + "> " if gene.strand == "+" else "< " + ) # indicate the strand like in ensembl browser + transcript_id += ( + transcript["transcript_name"] + if "transcript_name" in transcript + else f"{gene.name}_{transcript_number}" + ) # find next line that is not blocked try: - i = next(idx 
for idx, last in enumerate(blocked) if last < transcript['exons'][0][0]) + i = next( + idx + for idx, last in enumerate(blocked) + if last < transcript["exons"][0][0] + ) except StopIteration: i = len(blocked) blocked.append(transcript_end) @@ -462,48 +665,96 @@ def gene_track(self, ax=None, title=None, reference=True, select_transcripts=Non blocked[i] = transcript_end # use SQANTI color palette if colorbySqanti is True - if colorbySqanti and 'annotation' in transcript: - color = sqanti_palette[transcript['annotation'][0]]['color'] + if colorbySqanti and "annotation" in transcript: + color = sqanti_palette[transcript["annotation"][0]]["color"] # line from TSS to PAS at 0.25 - ax.plot((transcript_start, transcript_end), [i + .25] * 2, color=color) + ax.plot((transcript_start, transcript_end), [i + 0.25] * 2, color=color) if label_transcripts: - pos = (max(transcript_start, x_range[0]) + min(transcript_end, x_range[1])) / 2 - ax.text(pos, i - .02, transcript_id, ha='center', va='top', fontsize=label_fontsize, clip_on=True) - for j, (start, end) in enumerate(transcript['exons']): + pos = ( + max(transcript_start, x_range[0]) + min(transcript_end, x_range[1]) + ) / 2 + ax.text( + pos, + i - 0.02, + transcript_id, + ha="center", + va="top", + fontsize=label_fontsize, + clip_on=True, + ) + for j, (start, end) in enumerate(transcript["exons"]): cds = None - if 'CDS' in transcript or 'ORF' in transcript: - cds = transcript['CDS'] if 'CDS' in transcript else transcript['ORF'] + if "CDS" in transcript or "ORF" in transcript: + cds = transcript["CDS"] if "CDS" in transcript else transcript["ORF"] if cds is not None and cds[0] <= end and cds[1] >= start: # CODING exon - c_st, c_end = max(start, cds[0]), min(cds[1], end) # coding start and coding end + c_st, c_end = max(start, cds[0]), min( + cds[1], end + ) # coding start and coding end if c_st > start: # first noncoding part - rect = patches.Rectangle((start, i + .125), (c_st - start), .25, linewidth=1, edgecolor=color, 
facecolor=color) + rect = patches.Rectangle( + (start, i + 0.125), + (c_st - start), + 0.25, + linewidth=1, + edgecolor=color, + facecolor=color, + ) ax.add_patch(rect) if c_end < end: # 2nd noncoding part - rect = patches.Rectangle((c_end, i + .125), (end - c_end), .25, linewidth=1, edgecolor=color, facecolor=color) + rect = patches.Rectangle( + (c_end, i + 0.125), + (end - c_end), + 0.25, + linewidth=1, + edgecolor=color, + facecolor=color, + ) ax.add_patch(rect) # Coding part - rect = patches.Rectangle((c_st, i), (c_end - c_st), .5, linewidth=1, edgecolor=color, facecolor=color) + rect = patches.Rectangle( + (c_st, i), + (c_end - c_st), + 0.5, + linewidth=1, + edgecolor=color, + facecolor=color, + ) ax.add_patch(rect) else: # non coding - rect = patches.Rectangle((start, i + .125), (end - start), .25, linewidth=1, edgecolor=color, facecolor=color) + rect = patches.Rectangle( + (start, i + 0.125), + (end - start), + 0.25, + linewidth=1, + edgecolor=color, + facecolor=color, + ) ax.add_patch(rect) if label_exon_numbers and (end > x_range[0] and start < x_range[1]): - enr = j + 1 if gene.strand == '+' else len(transcript['exons']) - j + enr = j + 1 if gene.strand == "+" else len(transcript["exons"]) - j pos = (max(start, x_range[0]) + min(end, x_range[1])) / 2 - ax.text(pos, i + .25, enr, ha='center', va='center', color=contrast, fontsize=label_fontsize, - clip_on=True) # bbox=dict(boxstyle='round', facecolor='wheat',edgecolor=None, alpha=0.5) + ax.text( + pos, + i + 0.25, + enr, + ha="center", + va="center", + color=contrast, + fontsize=label_fontsize, + clip_on=True, + ) # bbox=dict(boxstyle='round', facecolor='wheat',edgecolor=None, alpha=0.5) i += 1 if title is None: - title = f'{self.name} ({self.region})' + title = f"{self.name} ({self.region})" ax.set_title(title) ax.set(frame_on=False) ax.get_yaxis().set_visible(False) - ax.set_ylim(-.5, len(blocked)) + ax.set_ylim(-0.5, len(blocked)) ax.set_xlim(*x_range) - ax.xaxis.set_major_formatter(FuncFormatter(lambda 
x, pos=None: f'{x:,.0f}')) + ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos=None: f"{x:,.0f}")) return ax @@ -513,19 +764,19 @@ def find_blocks(pos, segments, remove_zero_gaps=False): offset = 0 idx = 0 try: - while pos[0] > offset+segments[idx][1] - segments[idx][0]: + while pos[0] > offset + segments[idx][1] - segments[idx][0]: offset += segments[idx][1] - segments[idx][0] idx += 1 - adj_pos = [[segments[idx][0]+pos[0]-offset, None]] - while pos[1] > offset+segments[idx][1] - segments[idx][0]: + adj_pos = [[segments[idx][0] + pos[0] - offset, None]] + while pos[1] > offset + segments[idx][1] - segments[idx][0]: adj_pos[-1][1] = segments[idx][1] offset += segments[idx][1] - segments[idx][0] idx += 1 adj_pos.append([segments[idx][0], None]) except IndexError: - logger.error(f'attempt to postitions {pos} blocks to segments {segments}') + logger.error(f"attempt to postitions {pos} blocks to segments {segments}") raise - adj_pos[-1][1] = segments[idx][0]+pos[1]-offset + adj_pos[-1][1] = segments[idx][0] + pos[1] - offset if remove_zero_gaps: adj_pos_gaps = adj_pos adj_pos = [] @@ -537,36 +788,69 @@ def find_blocks(pos, segments, remove_zero_gaps=False): return adj_pos -def get_rects(blocks, h=1, w=.1, connect=False, **kwargs): - rects = [patches.Rectangle((b[0], h-w/2), b[1]-b[0], w, **kwargs) for b in blocks] +def get_rects(blocks, h=1, w=0.1, connect=False, **kwargs): + rects = [ + patches.Rectangle((b[0], h - w / 2), b[1] - b[0], w, **kwargs) for b in blocks + ] # Rectangle(xy=lower left, width, height) if connect: # draw a line between blocks - rects.extend([patches.Polygon(np.array([[b1[1], h], [b2[0], h]]), closed=True, **kwargs) for b1, b2 in pairwise(blocks) if b1[1] < b2[0]]) - return (rects) - - -def get_patches(blocks, orf, h, w1=.1, w2=.5, connect=True, **kwargs): + rects.extend( + [ + patches.Polygon( + np.array([[b1[1], h], [b2[0], h]]), closed=True, **kwargs + ) + for b1, b2 in pairwise(blocks) + if b1[1] < b2[0] + ] + ) + return rects + + 
+def get_patches(blocks, orf, h, w1=0.1, w2=0.5, connect=True, **kwargs): rects = [] - y11, y21 = h+w1/2, h-w1/2 # y for the smaller blocks - y12, y22 = h+w2/2, h-w2/2 # y for the larger blocks + y11, y21 = h + w1 / 2, h - w1 / 2 # y for the smaller blocks + y12, y22 = h + w2 / 2, h - w2 / 2 # y for the larger blocks if orf is None: orf = [blocks[0][0], blocks[0][0]] # 5'UTR blocks (small) - rects = [patches.Rectangle((b[0], y21), b[1]-b[0], w1, **kwargs) for b in blocks if b[1] <= orf[0]] + rects = [ + patches.Rectangle((b[0], y21), b[1] - b[0], w1, **kwargs) + for b in blocks + if b[1] <= orf[0] + ] # transition to CDS for b in blocks: if b[0] < orf[0] and b[1] > orf[0]: if b[1] > orf[1]: # transition in and out - x = [b[0], orf[0], orf[0], orf[1], orf[1], b[1], b[1], orf[1], orf[1], orf[0], orf[0], b[0]] - y = [y11, y11, y12, y12, y11, y11, y21, y21, y22, y22, y21, y21] + x = [ + b[0], + orf[0], + orf[0], + orf[1], + orf[1], + b[1], + b[1], + orf[1], + orf[1], + orf[0], + orf[0], + b[0], + ] + y = [y11, y11, y12, y12, y11, y11, y21, y21, y22, y22, y21, y21] else: # transition to CDS x = [b[0], orf[0], orf[0], b[1], b[1], orf[0], orf[0], b[0]] y = [y11, y11, y12, y12, y22, y22, y21, y21] rects.append(patches.Polygon(list(zip(x, y)), closed=True, **kwargs)) # CDS blocks(large) - rects.extend([patches.Rectangle((b[0], y22), b[1]-b[0], w2, **kwargs) for b in blocks if b[0] >= orf[0] and b[1] <= orf[1]]) + rects.extend( + [ + patches.Rectangle((b[0], y22), b[1] - b[0], w2, **kwargs) + for b in blocks + if b[0] >= orf[0] and b[1] <= orf[1] + ] + ) # transition to 3'UTR for b in blocks: if b[0] > orf[0] and b[0] < orf[1] and b[1] > orf[1]: @@ -574,32 +858,55 @@ def get_patches(blocks, orf, h, w1=.1, w2=.5, connect=True, **kwargs): y = [y12, y12, y11, y11, y21, y21, y22, y22] rects.append(patches.Polygon(list(zip(x, y)), closed=True, **kwargs)) # 3'UTR blocks (small) - rects.extend([patches.Rectangle((b[0], y21), b[1]-b[0], w1, **kwargs) for b in blocks if b[0] >= 
orf[1]]) + rects.extend( + [ + patches.Rectangle((b[0], y21), b[1] - b[0], w1, **kwargs) + for b in blocks + if b[0] >= orf[1] + ] + ) if connect: # draw a line between blocks - rects.extend([patches.Polygon(np.array([[b1[1], h], [b2[0], h]]), closed=False, **kwargs) for b1, b2 in pairwise(blocks) if b1[1] < b2[0]]) - return (rects) + rects.extend( + [ + patches.Polygon( + np.array([[b1[1], h], [b2[0], h]]), closed=False, **kwargs + ) + for b1, b2 in pairwise(blocks) + if b1[1] < b2[0] + ] + ) + return rects def find_segments(transcripts, orf_only=True, separate_exons=False): - '''Find exonic parts of the gene, with respect to transcript_ids.''' + """Find exonic parts of the gene, with respect to transcript_ids.""" if orf_only: exon_list = [] for transcript in transcripts: - cds_pos = transcript.get('CDS', transcript.get('ORF')) + cds_pos = transcript.get("CDS", transcript.get("ORF")) exon_list.append([]) if cds_pos is None: continue - for exon in transcript['exons']: + for exon in transcript["exons"]: if exon[1] < cds_pos[0]: continue if exon[0] > cds_pos[1]: break - exon_list[-1].append([max(exon[0], cds_pos[0]), min(exon[1], cds_pos[1])]) + exon_list[-1].append( + [max(exon[0], cds_pos[0]), min(exon[1], cds_pos[1])] + ) else: - exon_list = [transcript['exons'] for transcript in transcripts] - - junctions = sorted([(pos, bool(j), i) for i, cds in enumerate(exon_list) for e in cds for j, pos in enumerate(e)]) + exon_list = [transcript["exons"] for transcript in transcripts] + + junctions = sorted( + [ + (pos, bool(j), i) + for i, cds in enumerate(exon_list) + for e in cds + for j, pos in enumerate(e) + ] + ) open_c = 0 offset = 0 genome_map = [] # genomic interval, e.g.[([12345, 12445],0) ([12900-12980],100)] @@ -608,9 +915,11 @@ def find_segments(transcripts, orf_only=True, separate_exons=False): pre_pos = None for pos, is_end, tr_i in junctions: if open_c > 0: - offset += pos-pre_pos + offset += pos - pre_pos else: - assert not is_end, f'more exons closed than 
opened before: {pos} at {junctions}' + assert ( + not is_end + ), f"more exons closed than opened before: {pos} at {junctions}" genome_map.append([pos, None]) if not is_end: if separate_exons or not segments[tr_i] or segments[tr_i][-1][1] < offset: @@ -636,7 +945,7 @@ def genome_pos_to_gene_segments(pos, genome_map, strict=True): for seg in genome_map: while seg[1] >= pos[i]: if seg[0] <= pos[i]: - mapped_pos.append(offset+pos[i]-seg[0]) + mapped_pos.append(offset + pos[i] - seg[0]) elif not strict: mapped_pos.append(offset) else: @@ -645,21 +954,37 @@ def genome_pos_to_gene_segments(pos, genome_map, strict=True): if i == len(pos): break else: - offset += seg[1]-seg[0] + offset += seg[1] - seg[0] continue break else: - for i in range(i, len(pos)): + for _i in range(i, len(pos)): mapped_pos.append(None if strict else offset) if reverse_strand: - trlen = sum(seg[1]-seg[0] for seg in genome_map) - mapped_pos = [trlen-mp if mp is not None else None for mp in mapped_pos] + trlen = sum(seg[1] - seg[0] for seg in genome_map) + mapped_pos = [trlen - mp if mp is not None else None for mp in mapped_pos] return {p: mp for p, mp in zip(pos, mapped_pos)} -def plot_domains(self, source, categories=None, transcript_ids=True, ref_transcript_ids=False, coding_only=True, label='name', include_utr=False, separate_exons=True, - x_ticks='gene', ax=None, dom_space=.8, domain_cols=DOMAIN_COLS, max_overlap=5, highlight=None, highlight_col='red'): - '''Plot exonic part of transcripts, together with protein domains and annotations. +def plot_domains( + self, + source, + categories=None, + transcript_ids=True, + ref_transcript_ids=False, + coding_only=True, + label="name", + include_utr=False, + separate_exons=True, + x_ticks="gene", + ax=None, + dom_space=0.8, + domain_cols=DOMAIN_COLS, + max_overlap=5, + highlight=None, + highlight_col="red", +): + """Plot exonic part of transcripts, together with protein domains and annotations. :param source: Source of protein domains, e.g. 
"annotation", "hmmer" or "interpro", for domains added by the functions "add_annotation_domains", "add_hmmer_domains" or "add_interpro_domains" respectively. @@ -677,38 +1002,65 @@ def plot_domains(self, source, categories=None, transcript_ids=True, ref_transcr :param domain_cols: Dicionary for the colors of different domain types. :param max_overlap: Maximum number of overlapping domains to be depicted. Longer domains have priority over shorter domains. :param highlight: List of genomic positions or intervals to highlight. - :param highlight_col: Specify the color for highlight positions.''' + :param highlight_col: Specify the color for highlight positions.""" if label is not None: - assert label in ('id', 'name'), 'label needs to be either "id" or "name" (or None).' - label_idx = 0 if label == 'id' else 1 + assert label in ( + "id", + "name", + ), 'label needs to be either "id" or "name" (or None).' + label_idx = 0 if label == "id" else 1 - assert 0 < dom_space <= 1, 'dom_space should be between 0 and 1.' + assert 0 < dom_space <= 1, "dom_space should be between 0 and 1." domain_cols = {k.lower(): v for k, v in domain_cols.items()} - assert x_ticks in ["gene", "genome"], f'x_ticks should be "gene" or "genome", not "{x_ticks}"' + assert x_ticks in [ + "gene", + "genome", + ], f'x_ticks should be "gene" or "genome", not "{x_ticks}"' if not include_utr: - assert coding_only, 'coding_only can be set only if include_utr is also set.' + assert coding_only, "coding_only can be set only if include_utr is also set." 
if isinstance(transcript_ids, bool): transcript_ids = list(range(len(self.transcripts))) if transcript_ids else [] if coding_only: - transcript_ids = [transcript_id for transcript_id in transcript_ids if 'ORF' in self.transcripts[transcript_id] or 'CDS' in self.transcripts[transcript_id]] + transcript_ids = [ + transcript_id + for transcript_id in transcript_ids + if "ORF" in self.transcripts[transcript_id] + or "CDS" in self.transcripts[transcript_id] + ] if isinstance(ref_transcript_ids, bool): - ref_transcript_ids = list(range(len(self.ref_transcripts))) if ref_transcript_ids else [] + ref_transcript_ids = ( + list(range(len(self.ref_transcripts))) if ref_transcript_ids else [] + ) if coding_only: - ref_transcript_ids = [transcript_id for transcript_id in ref_transcript_ids if 'ORF' in self.ref_transcripts[transcript_id] or 'CDS' in self.ref_transcripts[transcript_id]] - transcripts = [(i, self.ref_transcripts[i]) for i in ref_transcript_ids] + [(i, self.transcripts[i]) for i in transcript_ids] + ref_transcript_ids = [ + transcript_id + for transcript_id in ref_transcript_ids + if "ORF" in self.ref_transcripts[transcript_id] + or "CDS" in self.ref_transcripts[transcript_id] + ] + transcripts = [(i, self.ref_transcripts[i]) for i in ref_transcript_ids] + [ + (i, self.transcripts[i]) for i in transcript_ids + ] n_transcripts = len(transcripts) if not transcripts: - logger.error('no transcripts with ORF specified') + logger.error("no transcripts with ORF specified") return if ax is None: _, ax = plt.subplots(1) skipped = 0 - segments, genome_map = find_segments([transcript for _, transcript in transcripts], orf_only=not include_utr, separate_exons=separate_exons) + segments, genome_map = find_segments( + [transcript for _, transcript in transcripts], + orf_only=not include_utr, + separate_exons=separate_exons, + ) max_len = max(seg[-1][1] for seg in segments) - assert max_len == sum(seg[1]-seg[0] for seg in genome_map) + assert max_len == sum(seg[1] - seg[0] for 
seg in genome_map) if self.strand == "-": - segments = [[[max_len-pos[1], max_len-pos[0]] for pos in reversed(seg)] for seg in segments] + segments = [ + [[max_len - pos[1], max_len - pos[0]] for pos in reversed(seg)] + for seg in segments + ] genome_map = tuple((pos[1], pos[0]) for pos in reversed(genome_map)) if highlight is not None: highlight_pos = set() @@ -721,10 +1073,16 @@ def plot_domains(self, source, categories=None, transcript_ids=True, ref_transcr pos_map = genome_pos_to_gene_segments(highlight_pos, genome_map, False) for pos in highlight: if isinstance(pos, collections.abc.Sequence): - assert len(pos) == 2, 'provide intervals as a sequence of length 2' + assert len(pos) == 2, "provide intervals as a sequence of length 2" # draw box box_x = sorted(pos_map[p] for p in pos) - patch = patches.Rectangle((box_x[0], -n_transcripts), box_x[1]-box_x[0], n_transcripts+1, edgecolor=highlight_col, facecolor=highlight_col) + patch = patches.Rectangle( + (box_x[0], -n_transcripts), + box_x[1] - box_x[0], + n_transcripts + 1, + edgecolor=highlight_col, + facecolor=highlight_col, + ) ax.add_patch(patch) else: # draw line ax.vlines(pos_map[pos], -n_transcripts, 1, colors=[highlight_col]) @@ -733,8 +1091,12 @@ def plot_domains(self, source, categories=None, transcript_ids=True, ref_transcr seg = segments[line] if include_utr: try: - orf_pos = transcript.get('CDS', transcript['ORF'])[:2] - orf_trpos = sorted(self.find_transcript_positions(transcript_id, orf_pos, reference=line < len(ref_transcript_ids))) + orf_pos = transcript.get("CDS", transcript["ORF"])[:2] + orf_trpos = sorted( + self.find_transcript_positions( + transcript_id, orf_pos, reference=line < len(ref_transcript_ids) + ) + ) orf_blocks = find_blocks(orf_trpos, seg, True) orf_segpos = [orf_blocks[0][0], orf_blocks[-1][1]] except KeyError: @@ -744,55 +1106,108 @@ def plot_domains(self, source, categories=None, transcript_ids=True, ref_transcr else: orf_segpos = [0, seg[-1][1]] orf_trpos = [0, None] - for 
rect in get_patches(seg, orf_segpos, h=-line, connect=True, - linewidth=1, edgecolor="black", facecolor="white"): + for rect in get_patches( + seg, + orf_segpos, + h=-line, + connect=True, + linewidth=1, + edgecolor="black", + facecolor="white", + ): ax.add_patch(rect) if orf_segpos is None: continue - domains = [dom for dom in transcript.get('domain', {}).get(source, []) if categories is None or dom[2] in categories] + domains = [ + dom + for dom in transcript.get("domain", {}).get(source, []) + if categories is None or dom[2] in categories + ] # sort by length - domains.sort(key=lambda x: x[3][1]-x[3][0], reverse=True) + domains.sort(key=lambda x: x[3][1] - x[3][0], reverse=True) # get positions relative to segments - dom_blocks = [find_blocks([p+orf_trpos[0] for p in dom[3]], seg, True) for dom in domains] + dom_blocks = [ + find_blocks([p + orf_trpos[0] for p in dom[3]], seg, True) + for dom in domains + ] dom_line = {} for idx, block in enumerate(dom_blocks): i = 0 block_interval = (block[0][0], block[-1][1]) - while any(has_overlap(block_interval, b[1]) for b in dom_line.setdefault(i, [])): + while any( + has_overlap(block_interval, b[1]) for b in dom_line.setdefault(i, []) + ): i += 1 if i >= max_overlap: skipped += 1 break else: - dom_line[i].append((idx, block_interval)) # idx in length-sorted domains + dom_line[i].append( + (idx, block_interval) + ) # idx in length-sorted domains - w = dom_space*.5/max(len(dom_line), 1) + w = dom_space * 0.5 / max(len(dom_line), 1) def get_line_y(i, n): - return n//2 + (i+1)//2 * (-1 if i % 2 else 1) + return n // 2 + (i + 1) // 2 * (-1 if i % 2 else 1) + for dom_l in dom_line: - h = -line+w*get_line_y(dom_l, len(dom_line)) + h = -line + w * get_line_y(dom_l, len(dom_line)) # ugly hack to make the domains align with the proteins - h -= w*(get_line_y(len(dom_line)-1, len(dom_line))+get_line_y(len(dom_line)-2, len(dom_line)))/2 + h -= ( + w + * ( + get_line_y(len(dom_line) - 1, len(dom_line)) + + get_line_y(len(dom_line) - 
2, len(dom_line)) + ) + / 2 + ) for idx, bl in dom_line[dom_l]: dom = domains[idx] try: - for rect in get_rects(dom_blocks[idx], h=h, w=w, linewidth=1, edgecolor="black", facecolor=domain_cols.get(dom[2].lower(), "white")): + for rect in get_rects( + dom_blocks[idx], + h=h, + w=w, + linewidth=1, + edgecolor="black", + facecolor=domain_cols.get(dom[2].lower(), "white"), + ): ax.add_patch(rect) except IndexError: - logger.error(f'cannot add patch for {dom_blocks[idx]}') + logger.error(f"cannot add patch for {dom_blocks[idx]}") raise if label is not None: - ax.text((bl[0]+bl[1])/2, h, dom[label_idx], ha='center', va='center', color='black', clip_on=True) + ax.text( + (bl[0] + bl[1]) / 2, + h, + dom[label_idx], + ha="center", + va="center", + color="black", + clip_on=True, + ) if skipped: - logger.warning("skipped %s domains, consider increasing max_overlap parameter", skipped) - ax.set_ylim(-len(transcripts)+.25, .75) - ax.set_xlim(-10, max_len+10) - if x_ticks == 'genome': - xticks = [0]+list(np.cumsum([abs(seg[1]-seg[0]) for seg in genome_map])) - xticklabels = [str(genome_map[0][0])]+[f'{seg[0][1]}|{seg[1][0]}' for seg in pairwise(genome_map)] + [str(genome_map[-1][1])] + logger.warning( + "skipped %s domains, consider increasing max_overlap parameter", skipped + ) + ax.set_ylim(-len(transcripts) + 0.25, 0.75) + ax.set_xlim(-10, max_len + 10) + if x_ticks == "genome": + xticks = [0] + list(np.cumsum([abs(seg[1] - seg[0]) for seg in genome_map])) + xticklabels = ( + [str(genome_map[0][0])] + + [f"{seg[0][1]}|{seg[1][0]}" for seg in pairwise(genome_map)] + + [str(genome_map[-1][1])] + ) ax.set_xticks(ticks=xticks, labels=xticklabels) - ax.set_yticks(ticks=[-i for i in range(len(transcripts))], labels=[transcript.get('transcript_name', f'{self.name} {transcript_id}') for transcript_id, transcript in transcripts]) + ax.set_yticks( + ticks=[-i for i in range(len(transcripts))], + labels=[ + transcript.get("transcript_name", f"{self.name} {transcript_id}") + for 
transcript_id, transcript in transcripts + ], + ) return ax, genome_map diff --git a/src/isotools/_transcriptome_filter.py b/src/isotools/_transcriptome_filter.py index c4c38bc..9a52fe5 100644 --- a/src/isotools/_transcriptome_filter.py +++ b/src/isotools/_transcriptome_filter.py @@ -9,51 +9,85 @@ from .transcriptome import Transcriptome from .gene import Gene -logger = logging.getLogger('isotools') -BOOL_OP = {'and', 'or', 'not', 'is'} -DEFAULT_GENE_FILTER = {'NOVEL_GENE': 'not reference', - 'EXPRESSED': 'transcripts', - 'CHIMERIC': 'chimeric'} +logger = logging.getLogger("isotools") +BOOL_OP = {"and", "or", "not", "is"} +DEFAULT_GENE_FILTER = { + "NOVEL_GENE": "not reference", + "EXPRESSED": "transcripts", + "CHIMERIC": "chimeric", +} DEFAULT_REF_TRANSCRIPT_FILTER = { - 'REF_UNSPLICED': 'len(exons)==1', - 'REF_MULTIEXON': 'len(exons)>1', - 'REF_INTERNAL_PRIMING': 'downstream_A_content>.5'} + "REF_UNSPLICED": "len(exons)==1", + "REF_MULTIEXON": "len(exons)>1", + "REF_INTERNAL_PRIMING": "downstream_A_content>.5", +} DEFAULT_TRANSCRIPT_FILTER = { # 'CLIPPED_ALIGNMENT':'clipping', - 'INTERNAL_PRIMING': 'len(exons)==1 and downstream_A_content and downstream_A_content>.5', # more than 50% a - 'RTTS': 'noncanonical_splicing is not None and novel_splice_sites is not None and \ - any(2*i in novel_splice_sites and 2*i+1 in novel_splice_sites for i,_ in noncanonical_splicing)', - 'NONCANONICAL_SPLICING': 'noncanonical_splicing', - 'NOVEL_TRANSCRIPT': 'annotation[0]>0', - 'FRAGMENT': 'fragments and any("novel exonic " in a or "fragment" in a for a in annotation[1])', - 'UNSPLICED': 'len(exons)==1', - 'MULTIEXON': 'len(exons)>1', - 'SUBSTANTIAL': 'gene.coverage.sum() * .01 < gene.coverage[:,trid].sum()', - 'HIGH_COVER': 'gene.coverage.sum(0)[trid] >= 7', - 'PERMISSIVE': 'gene.coverage.sum(0)[trid] >= 2 and (FSM or not (RTTS or INTERNAL_PRIMING or FRAGMENT))', - 'BALANCED': 'gene.coverage.sum(0)[trid] >= 2 and (FSM or (HIGH_COVER and not (RTTS or FRAGMENT or 
INTERNAL_PRIMING)))', - 'STRICT': 'gene.coverage.sum(0)[trid] >= 7 and SUBSTANTIAL and (FSM or not (RTTS or FRAGMENT or INTERNAL_PRIMING))', - 'CAGE_SUPPORT': 'sqanti_classification is not None and sqanti_classification["within_CAGE_peak"]', - 'TSS_RATIO': 'sqanti_classification is not None and sqanti_classification["ratio_TSS"] > 1.5', - 'POLYA_MOTIF': 'sqanti_classification is not None and sqanti_classification["polyA_motif_found"]', - 'POLYA_SITE': 'sqanti_classification is not None and sqanti_classification["within_polyA_site"]', + "INTERNAL_PRIMING": "len(exons)==1 and downstream_A_content and downstream_A_content>.5", # more than 50% a + "RTTS": "noncanonical_splicing is not None and novel_splice_sites is not None and \ + any(2*i in novel_splice_sites and 2*i+1 in novel_splice_sites for i,_ in noncanonical_splicing)", + "NONCANONICAL_SPLICING": "noncanonical_splicing", + "NOVEL_TRANSCRIPT": "annotation[0]>0", + "FRAGMENT": 'fragments and any("novel exonic " in a or "fragment" in a for a in annotation[1])', + "UNSPLICED": "len(exons)==1", + "MULTIEXON": "len(exons)>1", + "SUBSTANTIAL": "gene.coverage.sum() * .01 < gene.coverage[:,trid].sum()", + "HIGH_COVER": "gene.coverage.sum(0)[trid] >= 7", + "PERMISSIVE": "gene.coverage.sum(0)[trid] >= 2 and (FSM or not (RTTS or INTERNAL_PRIMING or FRAGMENT))", + "BALANCED": "gene.coverage.sum(0)[trid] >= 2 and (FSM or (HIGH_COVER and not (RTTS or FRAGMENT or INTERNAL_PRIMING)))", + "STRICT": "gene.coverage.sum(0)[trid] >= 7 and SUBSTANTIAL and (FSM or not (RTTS or FRAGMENT or INTERNAL_PRIMING))", + "CAGE_SUPPORT": 'sqanti_classification is not None and sqanti_classification["within_CAGE_peak"]', + "TSS_RATIO": 'sqanti_classification is not None and sqanti_classification["ratio_TSS"] > 1.5', + "POLYA_MOTIF": 'sqanti_classification is not None and sqanti_classification["polyA_motif_found"]', + "POLYA_SITE": 'sqanti_classification is not None and sqanti_classification["within_polyA_site"]', } -SPLICE_CATEGORY = ['FSM', 
'ISM', 'NIC', 'NNC', 'NOVEL'] - - -ANNOTATION_VOCABULARY = ['antisense', 'intergenic', 'genic genomic', 'novel exonic PAS', 'novel intronic PAS', 'readthrough fusion', - 'novel exon', "novel 3' splice site", 'intron retention', "novel 5' splice site", 'exon skipping', 'novel combination', - 'novel intronic TSS', 'novel exonic TSS', 'mono-exon', 'novel junction', "5' fragment", "3' fragment", 'intronic'] +SPLICE_CATEGORY = ["FSM", "ISM", "NIC", "NNC", "NOVEL"] + + +ANNOTATION_VOCABULARY = [ + "antisense", + "intergenic", + "genic genomic", + "novel exonic PAS", + "novel intronic PAS", + "readthrough fusion", + "novel exon", + "novel 3' splice site", + "intron retention", + "novel 5' splice site", + "exon skipping", + "novel combination", + "novel intronic TSS", + "novel exonic TSS", + "mono-exon", + "novel junction", + "5' fragment", + "3' fragment", + "intronic", +] # filtering functions for the transcriptome class -def add_orf_prediction(self: 'Transcriptome', genome_fn, progress_bar=True, filter_transcripts={}, filter_ref_transcripts={}, min_len=300, max_5utr_len=500, - min_kozak=None, prefer_annotated_init=True, kozak_matrix=DEFAULT_KOZAK_PWM, fickett_score=True, hexamer_file=None): - ''' Performs ORF prediction on the transcripts. + +def add_orf_prediction( + self: "Transcriptome", + genome_fn, + progress_bar=True, + filter_transcripts=None, + filter_ref_transcripts=None, + min_len=300, + max_5utr_len=500, + min_kozak=None, + prefer_annotated_init=True, + kozak_matrix=DEFAULT_KOZAK_PWM, + fickett_score=True, + hexamer_file=None, +): + """Performs ORF prediction on the transcripts. For each transcript the first valid open reading frame is determined, and metrics to assess the coding potential (UTR and CDS lengths, Kozak score, Fickett score, hexamer score and NMD prediction). 
The hexamer score depends on hexamer frequency table, @@ -66,7 +100,12 @@ def add_orf_prediction(self: 'Transcriptome', genome_fn, progress_bar=True, filt :param prefer_annotated_init: If True, the initiation sites of annotated CDS are preferred. :param kozak_matrix: PWM (log odds ratios) for the kozak sequence similarity score. :param fickett_score: If set to True, the Fickett TESTCODE score is computed for the ORF. - :param hexamer_file: Filename of the hexamer table, for the ORF hexamer scores. If set not None, the hexamer score is not computed.''' + :param hexamer_file: Filename of the hexamer table, for the ORF hexamer scores. If not set, the hexamer score is not computed. + """ + if filter_transcripts is None: + filter_transcripts = {} + if filter_ref_transcripts is None: + filter_ref_transcripts = {} if hexamer_file is None: coding = None @@ -77,7 +116,7 @@ def add_orf_prediction(self: 'Transcriptome', genome_fn, progress_bar=True, filt for line in open(hexamer_file): line = line.strip() fields = line.split() - if fields[0] == 'hexamer': + if fields[0] == "hexamer": continue coding[fields[0]] = float(fields[1]) noncoding[fields[0]] = float(fields[2]) @@ -86,24 +125,58 @@ def add_orf_prediction(self: 'Transcriptome', genome_fn, progress_bar=True, filt missing_chr = set(self.chromosomes) - set(genome_fh.references) if missing_chr: missing_genes = sum(len(self.data[mc]) for mc in missing_chr) - logger.warning('%s contigs are not contained in genome, affecting %s genes. \ - ORFs cannot be computed for these contigs: %s', str(len(missing_chr)), str(missing_genes), str(missing_chr)) + logger.warning( + "%s contigs are not contained in genome, affecting %s genes. 
\ + ORFs cannot be computed for these contigs: %s", + str(len(missing_chr)), + str(missing_genes), + str(missing_chr), + ) for gene in self.iter_genes(progress_bar=progress_bar): if gene.chrom in genome_fh.references: if filter_transcripts is not None: - gene.add_orfs(genome_fh, reference=False, prefer_annotated_init=prefer_annotated_init, minlen=min_len, - min_kozak=min_kozak, max_5utr_len=max_5utr_len, tr_filter=filter_transcripts, - kozak_matrix=kozak_matrix, get_fickett=fickett_score, coding_hexamers=coding, noncoding_hexamers=noncoding) + gene.add_orfs( + genome_fh, + reference=False, + prefer_annotated_init=prefer_annotated_init, + minlen=min_len, + min_kozak=min_kozak, + max_5utr_len=max_5utr_len, + tr_filter=filter_transcripts, + kozak_matrix=kozak_matrix, + get_fickett=fickett_score, + coding_hexamers=coding, + noncoding_hexamers=noncoding, + ) if filter_ref_transcripts is not None: - gene.add_orfs(genome_fh, reference=True, prefer_annotated_init=prefer_annotated_init, minlen=min_len, - min_kozak=min_kozak, max_5utr_len=max_5utr_len, tr_filter=filter_ref_transcripts, - get_fickett=fickett_score, kozak_matrix=kozak_matrix, coding_hexamers=coding, noncoding_hexamers=noncoding) - - -def add_qc_metrics(self: 'Transcriptome', genome_fn: str, progress_bar=True, downstream_a_len=30, direct_repeat_wd=15, direct_repeat_wobble=2, direct_repeat_mm=2, - unify_ends=True, correct_tss=True): - ''' + gene.add_orfs( + genome_fh, + reference=True, + prefer_annotated_init=prefer_annotated_init, + minlen=min_len, + min_kozak=min_kozak, + max_5utr_len=max_5utr_len, + tr_filter=filter_ref_transcripts, + get_fickett=fickett_score, + kozak_matrix=kozak_matrix, + coding_hexamers=coding, + noncoding_hexamers=noncoding, + ) + + +def add_qc_metrics( + self: "Transcriptome", + genome_fn: str, + progress_bar=True, + downstream_a_len=30, + direct_repeat_wd=15, + direct_repeat_wobble=2, + direct_repeat_mm=2, + unify_ends=True, + correct_tss=True, +): + """ Retrieves QC metrics for the 
transcripts. Calling this function populates transcript["biases"] information, which can be used do create filters. @@ -117,43 +190,53 @@ def add_qc_metrics(self: 'Transcriptome', genome_fn: str, progress_bar=True, dow :param direct_repeat_mm: Maximum number of missmatches in a direct repeat. :param unify_ends: If set, the TSS and PAS are unified using peak calling. :param correct_tss: If set TSS are corrected with respect to the reference annotation. Only used if unify_ends is set. - ''' + """ with FastaFile(genome_fn) as genome_fh: missing_chr = set(self.chromosomes) - set(genome_fh.references) if missing_chr: missing_genes = sum(len(self.data[mc]) for mc in missing_chr) - logger.warning('%s contigs are not contained in genome, affecting %s genes. \ - Some metrics cannot be computed: %s', str(len(missing_chr)), str(missing_genes), str(missing_chr)) + logger.warning( + "%s contigs are not contained in genome, affecting %s genes. \ + Some metrics cannot be computed: %s", + str(len(missing_chr)), + str(missing_genes), + str(missing_chr), + ) for gene in self.iter_genes(progress_bar=progress_bar): if unify_ends: # remove segment graph (if unify TSS/PAS option selected) - gene.data['segment_graph'] = None + gene.data["segment_graph"] = None # "unify" TSS/PAS (if unify TSS/PAS option selected) gene._unify_ends(correct_tss=correct_tss) # compute segment graph (if not present) _ = gene.segment_graph gene.add_fragments() if gene.chrom in genome_fh.references: - gene.add_direct_repeat_len(genome_fh, delta=direct_repeat_wd, max_mm=direct_repeat_mm, wobble=direct_repeat_wobble) + gene.add_direct_repeat_len( + genome_fh, + delta=direct_repeat_wd, + max_mm=direct_repeat_mm, + wobble=direct_repeat_wobble, + ) gene.add_noncanonical_splicing(genome_fh) gene.add_threeprime_a_content(genome_fh, length=downstream_a_len) - self.infos['biases'] = True # flag to check that the function was called + self.infos["biases"] = True # flag to check that the function was called def 
remove_filter(self, tag): - '''Removes definition of filter tag. + """Removes definition of filter tag. - :param tag: Specify the tag of the filter definition to remove.''' + :param tag: Specify the tag of the filter definition to remove.""" old = [f.pop(tag, None) for f in self.filter.values()] if not any(old): - logger.error('filter tag %s not found', tag) + logger.error("filter tag %s not found", tag) -def add_filter(self, tag, expression, context='transcript', update=False): - '''Defines a new filter for gene, transcripts and reference transcripts. +def add_filter(self, tag, expression, context="transcript", update=False): + """Defines a new filter for gene, transcripts and reference transcripts. The provided expressions is evaluated during filtering in the provided context, when specified in a query string of a function that supports filtering. Importantly, filtering does not modify the original data; rather, it is only applied when specifying the query string. @@ -164,20 +247,44 @@ def add_filter(self, tag, expression, context='transcript', update=False): :param expression: Expression to be evaluated on gene, transcript, or reference transcript. Can use existing filters from the same context. :param context: The context for the filter expression, either "gene", "transcript" or "reference". - :param update: If set, the already present definition of the provided tag gets overwritten.''' - - assert context in ['gene', 'transcript', 'reference'], "filter context must be 'gene', 'transcript' or 'reference'" - assert tag == re.findall(r'\b\w+\b', tag)[0], '"tag" must be a single word' + :param update: If set, the already present definition of the provided tag gets overwritten. 
+ """ + + assert context in [ + "gene", + "transcript", + "reference", + ], "filter context must be 'gene', 'transcript' or 'reference'" + assert tag == re.findall(r"\b\w+\b", tag)[0], '"tag" must be a single word' if not update: - assert tag not in self.filter[context], f"Filter tag {tag} is already present: `{self.filter[context][tag]}`. Set update=True to re-define." - if context == 'gene': + assert ( + tag not in self.filter[context] + ), f"Filter tag {tag} is already present: `{self.filter[context][tag]}`. Set update=True to re-define." + if context == "gene": attributes = {k for gene in self for k in gene.data.keys() if k.isidentifier()} else: - attributes = {'gene', 'trid'} - if context == 'transcript': - attributes.update({k for gene in self for transcript in gene.transcripts for k in transcript.keys() if k.isidentifier()}) - elif context == 'reference': - attributes.update({k for gene in self if gene.is_annotated for transcript in gene.ref_transcripts for k in transcript.keys() if k.isidentifier()}) + attributes = {"gene", "trid"} + if context == "transcript": + attributes.update( + { + k + for gene in self + for transcript in gene.transcripts + for k in transcript.keys() + if k.isidentifier() + } + ) + elif context == "reference": + attributes.update( + { + k + for gene in self + if gene.is_annotated + for transcript in gene.ref_transcripts + for k in transcript.keys() + if k.isidentifier() + } + ) # test whether the expression can be evaluated try: @@ -185,21 +292,33 @@ def add_filter(self, tag, expression, context='transcript', update=False): # _=f() # this would fail for many default expressions - can be avoided by checking if used attributes are None - but not ideal # Could be extended by dummy gene/transcript argument except BaseException: - logger.error('expression cannot be evaluated:\n%s', expression) + logger.error("expression cannot be evaluated:\n%s", expression) raise unknown_attr = [attr for attr in f_args if attr not in attributes] if 
unknown_attr: - logger.warning(f"Some attributes not present in {context} context, please make sure there is no typo: {','.join(unknown_attr)}\n\ - \rThis can happen for correct filters when there are no or only a few transcripts loaded into the model.") + logger.warning( + f"Some attributes not present in {context} context, please make sure there is no typo: {','.join(unknown_attr)}\n\ + \rThis can happen for correct filters when there are no or only a few transcripts loaded into the model." + ) if update: # avoid the same tag in different context for old_context, filter_dict in self.filter.items(): if filter_dict.pop(tag, None) is not None: - logger.info('replaced existing filter rule %s in %s context', tag, old_context) + logger.info( + "replaced existing filter rule %s in %s context", tag, old_context + ) self.filter[context][tag] = expression -def iter_genes(self: 'Transcriptome', region=None, query=None, min_coverage=None, max_coverage=None, gois=None, progress_bar=False): - '''Iterates over the genes of a region, optionally applying filters. +def iter_genes( + self: "Transcriptome", + region=None, + query=None, + min_coverage=None, + max_coverage=None, + gois=None, + progress_bar=False, +): + """Iterates over the genes of a region, optionally applying filters. :param region: The region to be considered. Either a string "chr:start-end", or a tuple (chr, start, end). Start and end is optional. If omitted, the complete genome is searched. @@ -208,22 +327,27 @@ def iter_genes(self: 'Transcriptome', region=None, query=None, min_coverage=None :param max_coverage: The maximum coverage threshold. Genes with more reads in total are ignored. :param gois: If provided, only a collection of genes of interest are considered, either gene ids or gene names. By default, all the genes are considered. 
- :param progress_bar: If set True, the progress bar is shown.''' + :param progress_bar: If set True, the progress bar is shown.""" if query: query_fun, used_tags = _filter_function(query) # used_tags={tag for tag in re.findall(r'\b\w+\b', query) if tag not in BOOL_OP} - all_filter = list(self.filter['gene']) - msg = 'did not find the following filter rules: {}\nvalid rules are: {}' + all_filter = list(self.filter["gene"]) + msg = "did not find the following filter rules: {}\nvalid rules are: {}" assert all(f in all_filter for f in used_tags), msg.format( - ', '.join(f for f in used_tags if f not in all_filter), ', '.join(all_filter)) - filter_fun = {tag: _filter_function(tag, self.filter['gene'])[0] for tag in used_tags} + ", ".join(f for f in used_tags if f not in all_filter), + ", ".join(all_filter), + ) + filter_fun = { + tag: _filter_function(tag, self.filter["gene"])[0] for tag in used_tags + } - try: # test the filter expression with dummy tags + # test the filter expression with dummy tags + try: query_fun(**{tag: True for tag in used_tags}) - except BaseException: - logger.error("Error in query string: \n{query}") - raise + except Exception as e: + logger.error("Error in query string: \n%s", query) + raise e if region is None: if gois is None: @@ -237,32 +361,47 @@ def iter_genes(self: 'Transcriptome', region=None, query=None, min_coverage=None start = None else: # parse region string (chr:start-end) try: - chrom, pos = region.split(':') - start, end = [int(i) for i in pos.split('-')] - except BaseException as e: - raise ValueError('incorrect region {} - specify as string "chr" or "chr:start-end" or tuple ("chr",start,end)'.format(region)) from e + chrom, pos = region.split(":") + start, end = [int(i) for i in pos.split("-")] + except Exception as e: + raise ValueError( + 'incorrect region {} - specify as string "chr" or "chr:start-end" or tuple ("chr",start,end)'.format( + region + ) + ) from e elif isinstance(region, tuple): chrom, start, end = region if 
start is not None: if chrom in self.data: - genes = self.data[chrom][int(start):int(end)] + genes = self.data[chrom][int(start) : int(end)] else: - raise ValueError('specified chromosome {} not found'.format(chrom)) + raise ValueError("specified chromosome {} not found".format(chrom)) if gois is not None: genes = [gene for gene in genes if gene.id in gois or gene.name in gois] # often some genes take much longer than others - smoothing 0 means avg - for gene in tqdm(genes, disable=not progress_bar, unit='genes', smoothing=0): + for gene in tqdm(genes, disable=not progress_bar, unit="genes", smoothing=0): if min_coverage is not None and gene.coverage.sum() < min_coverage: continue if max_coverage is not None and gene.coverage.sum() > max_coverage: continue - if query is None or query_fun(**{tag: fun(**gene.data) for tag, fun in filter_fun.items()}): + if query is None or query_fun( + **{tag: fun(**gene.data) for tag, fun in filter_fun.items()} + ): yield gene -def iter_transcripts(self: 'Transcriptome', region=None, query=None, min_coverage=None, max_coverage=None, genewise=False, gois=None, progress_bar=False): - '''Iterates over the transcripts of a region, optionally applying filters. +def iter_transcripts( + self: "Transcriptome", + region=None, + query=None, + min_coverage=None, + max_coverage=None, + genewise=False, + gois=None, + progress_bar=False, +): + """Iterates over the transcripts of a region, optionally applying filters. By default, each iteration returns a 3 Tuple with the gene object, the transcript number and the transcript dictionary. @@ -274,17 +413,27 @@ def iter_transcripts(self: 'Transcriptome', region=None, query=None, min_coverag :param genewise: In each iteration, return the gene and all transcript numbers and transcript dicts for the gene as tuples. :param gois: If provided, only transcripts from the list of genes of interest are considered. Provide as a list of gene ids or gene names. By default, all the genes are considered. 
- :param progress_bar: Print a progress bar. ''' + :param progress_bar: Print a progress bar.""" if query: # used_tags={tag for tag in re.findall(r'\b\w+\b', query) if tag not in BOOL_OP} - all_filter = list(self.filter['transcript']) + list(self.filter['gene']) + all_filter = list(self.filter["transcript"]) + list(self.filter["gene"]) query_fun, used_tags = _filter_function(query) - msg = 'did not find the following filter rules: {}\nvalid rules are: {}' + msg = "did not find the following filter rules: {}\nvalid rules are: {}" assert all(f in all_filter for f in used_tags), msg.format( - ', '.join(f for f in used_tags if f not in all_filter), ', '.join(all_filter)) - transcript_filter_fun = {tag: _filter_function(tag, self.filter['transcript'])[0] for tag in used_tags if tag in self.filter['transcript']} - gene_filter_fun = {tag: _filter_function(tag, self.filter['gene'])[0] for tag in used_tags if tag in self.filter['gene']} + ", ".join(f for f in used_tags if f not in all_filter), + ", ".join(all_filter), + ) + transcript_filter_fun = { + tag: _filter_function(tag, self.filter["transcript"])[0] + for tag in used_tags + if tag in self.filter["transcript"] + } + gene_filter_fun = { + tag: _filter_function(tag, self.filter["gene"])[0] + for tag in used_tags + if tag in self.filter["gene"] + } # test the filter expression with dummy tags try: @@ -297,21 +446,54 @@ def iter_transcripts(self: 'Transcriptome', region=None, query=None, min_coverag gene_filter_fun = {} if genewise: - for gene in self.iter_genes(region=region, gois=gois, progress_bar=progress_bar): - gene_filter_eval = {tag: fun(**gene.data) for tag, fun in gene_filter_fun.items()} - filter_result = tuple(_filter_transcripts(gene, gene.transcripts, query_fun, transcript_filter_fun, gene_filter_eval, min_coverage, max_coverage)) + for gene in self.iter_genes( + region=region, gois=gois, progress_bar=progress_bar + ): + gene_filter_eval = { + tag: fun(**gene.data) for tag, fun in gene_filter_fun.items() + 
} + filter_result = tuple( + _filter_transcripts( + gene, + gene.transcripts, + query_fun, + transcript_filter_fun, + gene_filter_eval, + min_coverage, + max_coverage, + ) + ) if filter_result: i_tuple, transcript_tuple = zip(*filter_result) yield gene, i_tuple, transcript_tuple else: - for gene in self.iter_genes(region=region, gois=gois, progress_bar=progress_bar): - gene_filter_eval = {tag: fun(**gene.data) for tag, fun in gene_filter_fun.items()} - for i, transcript in _filter_transcripts(gene, gene.transcripts, query_fun, transcript_filter_fun, gene_filter_eval, min_coverage, max_coverage): + for gene in self.iter_genes( + region=region, gois=gois, progress_bar=progress_bar + ): + gene_filter_eval = { + tag: fun(**gene.data) for tag, fun in gene_filter_fun.items() + } + for i, transcript in _filter_transcripts( + gene, + gene.transcripts, + query_fun, + transcript_filter_fun, + gene_filter_eval, + min_coverage, + max_coverage, + ): yield gene, i, transcript -def iter_ref_transcripts(self: 'Transcriptome', region=None, query=None, genewise=False, gois=None, progress_bar=False): - '''Iterates over the referemce transcripts of a region, optionally applying filters. +def iter_ref_transcripts( + self: "Transcriptome", + region=None, + query=None, + genewise=False, + gois=None, + progress_bar=False, +): + """Iterates over the reference transcripts of a region, optionally applying filters. :param region: The region to be considered. Either a string "chr:start-end", or a tuple (chr,start,end). Start and end is optional. If omitted, the complete genome is searched. @@ -320,17 +502,27 @@ def iter_ref_transcripts(self: 'Transcriptome', region=None, query=None, genewis :param genewise: In each iteration, return the gene and all transcript numbers and transcript dicts for the gene as tuples. :param gois: If provided, only transcripts from the list of genes of interest are considered. Provide as a list of gene ids or gene names. By default, all the genes are considered. 
- :param progress_bar: Print a progress bar. ''' + :param progress_bar: Print a progress bar.""" if query: # used_tags={tag for tag in re.findall(r'\b\w+\b', query) if tag not in BOOL_OP} - all_filter = list(self.filter['reference']) + list(self.filter['gene']) + all_filter = list(self.filter["reference"]) + list(self.filter["gene"]) query_fun, used_tags = _filter_function(query) - msg = 'did not find the following filter rules: {}\nvalid rules are: {}' - ref_filter_fun = {tag: _filter_function(tag, self.filter['reference'])[0] for tag in used_tags if tag in self.filter['reference']} - gene_filter_fun = {tag: _filter_function(tag, self.filter['gene'])[0] for tag in used_tags if tag in self.filter['gene']} + msg = "did not find the following filter rules: {}\nvalid rules are: {}" + ref_filter_fun = { + tag: _filter_function(tag, self.filter["reference"])[0] + for tag in used_tags + if tag in self.filter["reference"] + } + gene_filter_fun = { + tag: _filter_function(tag, self.filter["gene"])[0] + for tag in used_tags + if tag in self.filter["gene"] + } assert all(f in all_filter for f in used_tags), msg.format( - ', '.join(f for f in used_tags if f not in all_filter), ', '.join(all_filter)) + ", ".join(f for f in used_tags if f not in all_filter), + ", ".join(all_filter), + ) try: # test the filter expression with dummy tags _ = query_fun(**{tag: True for tag in used_tags}) except BaseException: @@ -340,38 +532,74 @@ def iter_ref_transcripts(self: 'Transcriptome', region=None, query=None, genewis ref_filter_fun = query_fun = None gene_filter_fun = {} if genewise: - for gene in self.iter_genes(region=region, gois=gois, progress_bar=progress_bar): - gene_filter_eval = {tag: fun(**gene.data) for tag, fun in gene_filter_fun.items()} - filter_result = tuple(_filter_transcripts(gene, gene.ref_transcripts, query_fun, ref_filter_fun, gene_filter_eval)) + for gene in self.iter_genes( + region=region, gois=gois, progress_bar=progress_bar + ): + gene_filter_eval = { + tag: 
fun(**gene.data) for tag, fun in gene_filter_fun.items() + } + filter_result = tuple( + _filter_transcripts( + gene, + gene.ref_transcripts, + query_fun, + ref_filter_fun, + gene_filter_eval, + ) + ) if filter_result: i_tuple, transcript_tuple = zip(*filter_result) yield gene, i_tuple, transcript_tuple else: - for gene in self.iter_genes(region=region, gois=gois, progress_bar=progress_bar): + for gene in self.iter_genes( + region=region, gois=gois, progress_bar=progress_bar + ): if gene.is_annotated: - gene_filter_eval = {tag: fun(**gene.data) for tag, fun in gene_filter_fun.items()} - for i, transcript in _filter_transcripts(gene, gene.ref_transcripts, query_fun, ref_filter_fun, gene_filter_eval): + gene_filter_eval = { + tag: fun(**gene.data) for tag, fun in gene_filter_fun.items() + } + for i, transcript in _filter_transcripts( + gene, + gene.ref_transcripts, + query_fun, + ref_filter_fun, + gene_filter_eval, + ): yield gene, i, transcript def _eval_filter_fun(fun, name, **args): - '''Decorator for the filter functions, which are lambdas and thus cannot have normal decorators. - On exceptions the provided parameters are reported. This is helpful for debugging.''' + """Decorator for the filter functions, which are lambdas and thus cannot have normal decorators. + On exceptions the provided parameters are reported. This is helpful for debugging. + """ try: return fun(**args) except Exception as e: - logger.error('error when evaluating filter %s with arguments %s: %s', name, str(args), str(e)) + logger.error( + "error when evaluating filter %s with arguments %s: %s", + name, + str(args), + str(e), + ) raise # either stop evaluation # return False #or continue -def _filter_transcripts(gene: 'Gene', transcripts, query_fun, filter_fun, g_filter_eval, mincoverage=None, maxcoverage=None): - ''' Iterator over the transcripts of the gene. 
+def _filter_transcripts( + gene: "Gene", + transcripts, + query_fun, + filter_fun, + g_filter_eval, + mincoverage=None, + maxcoverage=None, +): + """Iterator over the transcripts of the gene. Transcrips are specified by lists of flags submitted to the parameters. :param query_fun: function to be evaluated on tags - :param filter_fun: tags to be evalutated on transcripts''' + :param filter_fun: tags to be evaluated on transcripts""" for i, transcript in enumerate(transcripts): if mincoverage and gene.coverage[:, i].sum() < mincoverage: continue @@ -379,6 +607,11 @@ def _filter_transcripts(gene: 'Gene', transcripts, query_fun, filter_fun, g_filt continue filter_transcript = transcript.copy() query_result = query_fun is None or query_fun( - **g_filter_eval, **{tag: _eval_filter_fun(f, tag, gene=gene, trid=i, **filter_transcript) for tag, f in filter_fun.items()}) + **g_filter_eval, + **{ + tag: _eval_filter_fun(f, tag, gene=gene, trid=i, **filter_transcript) + for tag, f in filter_fun.items() + }, + ) if query_result: yield i, transcript diff --git a/src/isotools/_transcriptome_io.py b/src/isotools/_transcriptome_io.py index 4cb2dca..7f0d7d1 100644 --- a/src/isotools/_transcriptome_io.py +++ b/src/isotools/_transcriptome_io.py @@ -1,5 +1,6 @@ from __future__ import annotations import numpy as np + # from numpy.lib.function_base import percentile, quantile import pandas as pd from os import path @@ -11,8 +12,20 @@ from contextlib import ExitStack from .short_read import Coverage from typing import Tuple, TYPE_CHECKING -from ._utils import junctions_from_cigar, splice_identical, is_same_gene, has_overlap, get_overlap, pairwise, \ - cigar_string2tuples, rc, get_intersects, _find_splice_sites, _get_overlap, get_quantiles # , _get_exonic_region +from ._utils import ( + junctions_from_cigar, + splice_identical, + is_same_gene, + has_overlap, + get_overlap, + pairwise, + cigar_string2tuples, + rc, + get_intersects, + _find_splice_sites, + _get_overlap, + get_quantiles, +) 
# , _get_exonic_region from .gene import Gene, Transcript from .decorators import experimental import logging @@ -23,78 +36,114 @@ if TYPE_CHECKING: from .transcriptome import Transcriptome -logger = logging.getLogger('isotools') +logger = logging.getLogger("isotools") # io functions for the transcriptome class def add_short_read_coverage(self: Transcriptome, bam_files, load=False): - '''Adds short read coverage to the genes. + """Adds short read coverage to the genes. By default (e.g. if load==False), this method does not actually read the bams, but import for each gene is done at first access. :param bam_files: A dict with the sample names as keys, and the path to aligned short reads in bam format as values. - :param load: If True, the coverage of all genes is imported. WARNING: this may take a long time.''' - self.infos.setdefault('short_reads', pd.DataFrame(columns=['name', 'file'], dtype='object')) - - bam_files = {k: v for k, v in bam_files.items() if k not in self.infos['short_reads']['name']} - self.infos['short_reads'] = pd.concat([self.infos['short_reads'], - pd.DataFrame({'name': bam_files.keys(), 'file': bam_files.values()})], - ignore_index=True) - if load: # when loading coverage for all genes keep the filehandle open, hopefully a bit faster - for i, bamfile in enumerate(self.infos['short_reads'].file): - logger.info('Adding short read coverag from %s', bamfile) + :param load: If True, the coverage of all genes is imported. WARNING: this may take a long time. 
+ """ + self.infos.setdefault( + "short_reads", pd.DataFrame(columns=["name", "file"], dtype="object") + ) + + bam_files = { + k: v for k, v in bam_files.items() if k not in self.infos["short_reads"]["name"] + } + self.infos["short_reads"] = pd.concat( + [ + self.infos["short_reads"], + pd.DataFrame({"name": bam_files.keys(), "file": bam_files.values()}), + ], + ignore_index=True, + ) + if ( + load + ): # when loading coverage for all genes keep the filehandle open, hopefully a bit faster + for i, bamfile in enumerate(self.infos["short_reads"].file): + logger.info("Adding short read coverag from %s", bamfile) with AlignmentFile(bamfile, "rb") as align: for gene in tqdm(self): - gene.data.setdefault('short_reads', list()) - if len(gene.data['short_reads']) == i: - gene.data['short_reads'].append(Coverage.from_alignment(align, gene)) + gene.data.setdefault("short_reads", list()) + if len(gene.data["short_reads"]) == i: + gene.data["short_reads"].append( + Coverage.from_alignment(align, gene) + ) def remove_short_read_coverage(self: Transcriptome): - '''Removes short read coverage. + """Removes short read coverage. - Removes all short read coverage information from self.''' + Removes all short read coverage information from self.""" - if 'short_reads' in self.infos: - del self.infos['short_reads'] + if "short_reads" in self.infos: + del self.infos["short_reads"] for gene in self: - if 'short_reads' in gene: - del self.data['short_reads'] + if "short_reads" in gene: + del self.data["short_reads"] else: - logger.warning('No short read coverage to remove') + logger.warning("No short read coverage to remove") @experimental def remove_samples(self: Transcriptome, sample_names): - ''' Removes samples from the dataset. + """Removes samples from the dataset. 
- :params sample_names: A list of sample names to remove.''' + :params sample_names: A list of sample names to remove.""" if isinstance(sample_names, str): sample_names = [sample_names] - assert all(s in self.samples for s in sample_names), 'Did not find all samples to remvoe in dataset' + assert all( + s in self.samples for s in sample_names + ), "Did not find all samples to remvoe in dataset" sample_table = self.sample_table rm_idx = sample_table.index[sample_table.name.isin(sample_names)] sample_table = sample_table.drop(index=sample_table.index[rm_idx]) for gene in self: remove_transcript_ids = [] for i, transcript in enumerate(gene.transcripts): - if any(s in transcript['coverage'] for s in sample_names): - transcript['coverage'] = {s: cov for s, cov in transcript['coverage'].items() if s not in sample_names} - if not transcript['coverage']: + if any(s in transcript["coverage"] for s in sample_names): + transcript["coverage"] = { + s: cov + for s, cov in transcript["coverage"].items() + if s not in sample_names + } + if not transcript["coverage"]: remove_transcript_ids.append(i) - if remove_transcript_ids: # remove the transcripts that is not expressed by remaining samples - gene.data['transcripts'] = [transcript for i, transcript in enumerate(gene.transcripts) if i not in remove_transcript_ids] - gene.data['segment_graph'] = None # gets recomputed on next request - gene.data['coverage'] = None - - -def add_sample_from_csv(self: Transcriptome, coverage_csv_file, transcripts_file, transcript_id_col=None, sample_cov_cols=None, - sample_properties=None, add_chromosomes=True, infer_genes=False, reconstruct_genes=True, fuzzy_junction=0, - min_exonic_ref_coverage=.25, sep='\t'): - '''Imports expressed transcripts from coverage table and gtf/gff file, and adds it to the 'Transcriptome' object. 
+ if ( + remove_transcript_ids + ): # remove the transcripts that is not expressed by remaining samples + gene.data["transcripts"] = [ + transcript + for i, transcript in enumerate(gene.transcripts) + if i not in remove_transcript_ids + ] + gene.data["segment_graph"] = None # gets recomputed on next request + gene.data["coverage"] = None + + +def add_sample_from_csv( + self: Transcriptome, + coverage_csv_file, + transcripts_file, + transcript_id_col=None, + sample_cov_cols=None, + sample_properties=None, + add_chromosomes=True, + infer_genes=False, + reconstruct_genes=True, + fuzzy_junction=0, + min_exonic_ref_coverage=0.25, + sep="\t", +): + """Imports expressed transcripts from coverage table and gtf/gff file, and adds it to the 'Transcriptome' object. Transcript to gene assignment is either taken from the transcript_file, or recreated, as specified by the reconstruct_genes parameter. @@ -122,48 +171,72 @@ def add_sample_from_csv(self: Transcriptome, coverage_csv_file, transcripts_file :param progress_bar: Show the progress. :param sep: Specify the seperator for the coverage_csv_file. :return: Dict with map of renamed gene ids. 
-''' + """ cov_tab = pd.read_csv(coverage_csv_file, sep=sep) if sample_cov_cols is None: - sample_cov_cols = {c.replace('_coverage', ''): c for c in cov_tab.columns if '_coverage' in c} + sample_cov_cols = { + c.replace("_coverage", ""): c for c in cov_tab.columns if "_coverage" in c + } else: - assert all(c in cov_tab for c in sample_cov_cols.values()), 'coverage cols missing in %s: %s' % ( - coverage_csv_file, ', '.join(c for c in sample_cov_cols.values() if c not in cov_tab)) + assert all( + c in cov_tab for c in sample_cov_cols.values() + ), "coverage cols missing in %s: %s" % ( + coverage_csv_file, + ", ".join(c for c in sample_cov_cols.values() if c not in cov_tab), + ) samples = list(sample_cov_cols) if transcript_id_col is None: - if 'transcript_id' in cov_tab.columns: + if "transcript_id" in cov_tab.columns: pass - elif 'gene_id' in cov_tab.columns and 'transcript_nr' in cov_tab.columns: - cov_tab['transcript_id'] = cov_tab.gene_id+'_'+cov_tab.transcript_nr.astype(str) + elif "gene_id" in cov_tab.columns and "transcript_nr" in cov_tab.columns: + cov_tab["transcript_id"] = ( + cov_tab.gene_id + "_" + cov_tab.transcript_nr.astype(str) + ) else: - raise ValueError('"transcript_id_col" not specified, and coverage table does not contain "transcript_id", nor "gene_id" and "transcript_nr"') + raise ValueError( + '"transcript_id_col" not specified, and coverage table does not contain "transcript_id", nor "gene_id" and "transcript_nr"' + ) elif isinstance(transcript_id_col, list): - assert all(c in cov_tab for c in transcript_id_col), 'missing specified transcript_id_col' - cov_tab['transcript_id'] = ['_'.join(str(v) for v in row) for _, row in cov_tab[transcript_id_col].iterrows()] + assert all( + c in cov_tab for c in transcript_id_col + ), "missing specified transcript_id_col" + cov_tab["transcript_id"] = [ + "_".join(str(v) for v in row) + for _, row in cov_tab[transcript_id_col].iterrows() + ] else: - assert transcript_id_col in cov_tab, 'missing specified 
transcript_id_col' - cov_tab['transcript_id'] = cov_tab[transcript_id_col] + assert transcript_id_col in cov_tab, "missing specified transcript_id_col" + cov_tab["transcript_id"] = cov_tab[transcript_id_col] # could be optimized, but code is easier when the id column always is transcript_id - transcript_id_col = 'transcript_id' + transcript_id_col = "transcript_id" known_sa = set(samples).intersection(self.samples) - assert not known_sa, 'Attempt to add known samples: %s' % known_sa + assert not known_sa, "Attempt to add known samples: %s" % known_sa # cov_tab.set_index('transcript_id') # assert cov_tab.index.is_unique, 'ambigous transcript ids in %s' % coverage_csv_file # check sample properties if sample_properties is None: - sample_properties = {sample: {'group': sample} for sample in samples} + sample_properties = {sample: {"group": sample} for sample in samples} elif isinstance(sample_properties, pd.DataFrame): - if 'name' in sample_properties: - sample_properties = sample_properties.set_index('name') - sample_properties = {sample: {k: v for k, v in row.items() if k not in {'file', 'nonchimeric_reads', 'chimeric_reads'}} - for sample, row in sample_properties.iterrows()} - assert all(sample in sample_properties for sample in samples), 'missing sample_properties for samples %s' % ', '.join( - (sample for sample in samples if sample not in sample_properties)) + if "name" in sample_properties: + sample_properties = sample_properties.set_index("name") + sample_properties = { + sample: { + k: v + for k, v in row.items() + if k not in {"file", "nonchimeric_reads", "chimeric_reads"} + } + for sample, row in sample_properties.iterrows() + } + assert all( + sample in sample_properties for sample in samples + ), "missing sample_properties for samples %s" % ", ".join( + (sample for sample in samples if sample not in sample_properties) + ) for sample in sample_properties: - sample_properties[sample].setdefault('group', sample) + 
sample_properties[sample].setdefault("group", sample) logger.info('adding samples "%s" from csv', '", "'.join(samples)) # consider chromosomes not in the reference? @@ -172,37 +245,46 @@ def add_sample_from_csv(self: Transcriptome, coverage_csv_file, transcripts_file else: chromosomes = self.chromosomes - logger.info('importing transcripts from %s. Please note transcripts with missing annotations will be skipped.', transcripts_file) - file_format = path.splitext(transcripts_file)[1].lstrip('.') - if file_format == 'gz': - file_format = path.splitext(transcripts_file[:-3])[1].lstrip('.') - if file_format == 'gtf': - exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gtf_file(transcripts_file, chromosomes=chromosomes, infer_genes=infer_genes) - elif file_format in ('gff', 'gff3'): # gff/gff3 - exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gff_file(transcripts_file, chromosomes=chromosomes, infer_genes=infer_genes) + logger.info( + "importing transcripts from %s. 
Please note transcripts with missing annotations will be skipped.", + transcripts_file, + ) + file_format = path.splitext(transcripts_file)[1].lstrip(".") + if file_format == "gz": + file_format = path.splitext(transcripts_file[:-3])[1].lstrip(".") + if file_format == "gtf": + exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gtf_file( + transcripts_file, chromosomes=chromosomes, infer_genes=infer_genes + ) + elif file_format in ("gff", "gff3"): # gff/gff3 + exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gff_file( + transcripts_file, chromosomes=chromosomes, infer_genes=infer_genes + ) else: - logger.warning('unknown file format %s of the transcriptome file', file_format) + logger.warning("unknown file format %s of the transcriptome file", file_format) - logger.debug('sorting exon positions...') + logger.debug("sorting exon positions...") for tid in exons: exons[tid].sort() - logger.info('adding coverage information for transcripts imported.') + logger.info("adding coverage information for transcripts imported.") if skipped: - cov_tab = cov_tab[~cov_tab['transcript_id'].isin(skipped['transcript'])] + cov_tab = cov_tab[~cov_tab["transcript_id"].isin(skipped["transcript"])] - if 'gene_id' not in cov_tab: + if "gene_id" not in cov_tab: gene_id_dict = {tid: gid for gid, tids in transcripts.items() for tid in tids} try: - cov_tab['gene_id'] = [gene_id_dict[tid] for tid in cov_tab.transcript_id] + cov_tab["gene_id"] = [gene_id_dict[tid] for tid in cov_tab.transcript_id] except KeyError as e: - logger.warning('transcript_id %s from csv file not found in gtf.' % e.args[0]) - if 'chr' not in cov_tab: + logger.warning( + "transcript_id %s from csv file not found in gtf." 
% e.args[0] + ) + if "chr" not in cov_tab: chrom_dict = {gid: chrom for chrom, gids in gene_infos.items() for gid in gids} try: - cov_tab['chr'] = [chrom_dict[gid] for gid in cov_tab.gene_id] + cov_tab["chr"] = [chrom_dict[gid] for gid in cov_tab.gene_id] except KeyError as e: - logger.warning('gene_id %s from csv file not found in gtf.', e.args[0]) + logger.warning("gene_id %s from csv file not found in gtf.", e.args[0]) used_transcripts = set() for _, row in cov_tab.iterrows(): @@ -213,21 +295,50 @@ def add_sample_from_csv(self: Transcriptome, coverage_csv_file, transcripts_file try: assert row.transcript_id not in used_transcripts transcript = transcripts[row.gene_id][row.transcript_id] - transcript['transcript_id'] = row.transcript_id - transcript['exons'] = sorted([list(e) for e in exons[row.transcript_id]]) # needs to be mutable - transcript['coverage'] = {sample: row[sample_cov_cols[sample]] for sample in samples if row[sample_cov_cols[sample]] > 0} - transcript['strand'] = gene_infos[row.chr][row.gene_id][0]['strand'] # gene_infos is a 3 tuple (info, start, end) - transcript['TSS'] = {sample: {transcript['exons'][0][0]: row[sample_cov_cols[sample]]} for sample in samples if row[sample_cov_cols[sample]] > 0} - transcript['PAS'] = {sample: {transcript['exons'][-1][1]: row[sample_cov_cols[sample]]} for sample in samples if row[sample_cov_cols[sample]] > 0} - if transcript['strand'] == '-': - transcript['TSS'], transcript['PAS'] = transcript['PAS'], transcript['TSS'] + transcript["transcript_id"] = row.transcript_id + transcript["exons"] = sorted( + [list(e) for e in exons[row.transcript_id]] + ) # needs to be mutable + transcript["coverage"] = { + sample: row[sample_cov_cols[sample]] + for sample in samples + if row[sample_cov_cols[sample]] > 0 + } + transcript["strand"] = gene_infos[row.chr][row.gene_id][0][ + "strand" + ] # gene_infos is a 3 tuple (info, start, end) + transcript["TSS"] = { + sample: {transcript["exons"][0][0]: row[sample_cov_cols[sample]]} + 
for sample in samples + if row[sample_cov_cols[sample]] > 0 + } + transcript["PAS"] = { + sample: {transcript["exons"][-1][1]: row[sample_cov_cols[sample]]} + for sample in samples + if row[sample_cov_cols[sample]] > 0 + } + if transcript["strand"] == "-": + transcript["TSS"], transcript["PAS"] = ( + transcript["PAS"], + transcript["TSS"], + ) used_transcripts.add(row.transcript_id) except KeyError as e: - logger.warning('skipping transcript %s from gene %s, missing infos in gtf: %s', row.transcript_id, row.gene_id, e.args[0]) + logger.warning( + "skipping transcript %s from gene %s, missing infos in gtf: %s", + row.transcript_id, + row.gene_id, + e.args[0], + ) except AssertionError as e: - logger.warning('skipping transcript %s from gene %s: duplicate transcript id; Error: %s', row.transcript_id, row.gene_id, e) + logger.warning( + "skipping transcript %s from gene %s: duplicate transcript id; Error: %s", + row.transcript_id, + row.gene_id, + e, + ) id_map = {} - novel_prefix = 'IT_novel_' + novel_prefix = "IT_novel_" if reconstruct_genes: # this approach ignores gene structure, and reassigns transcripts novel = {} @@ -235,18 +346,28 @@ def add_sample_from_csv(self: Transcriptome, coverage_csv_file, transcripts_file if row.transcript_id not in used_transcripts: continue transcript = transcripts[row.gene_id][row.transcript_id] - gene = _add_sample_transcript(self, transcript, row.chr, fuzzy_junction, min_exonic_ref_coverage) + gene = _add_sample_transcript( + self, transcript, row.chr, fuzzy_junction, min_exonic_ref_coverage + ) if gene is None: - transcript['_original_ids'] = (row.gene_id, row.transcript_id) - novel.setdefault(row.chr, []).append(Interval(transcript['exons'][0][0], transcript['exons'][-1][1], transcript)) + transcript["_original_ids"] = (row.gene_id, row.transcript_id) + novel.setdefault(row.chr, []).append( + Interval( + transcript["exons"][0][0], + transcript["exons"][-1][1], + transcript, + ) + ) elif gene.id != row.gene_id: 
id_map.setdefault(row.gene_id, {})[row.transcript_id] = gene.id for chrom in novel: - novel_genes = _add_novel_genes(self, IntervalTree(novel[chrom]), chrom, gene_prefix=novel_prefix) + novel_genes = _add_novel_genes( + self, IntervalTree(novel[chrom]), chrom, gene_prefix=novel_prefix + ) for novel_g in novel_genes: for novel_transcript in novel_g.transcripts: - import_id = novel_transcript.pop('_original_ids') + import_id = novel_transcript.pop("_original_ids") if novel_g.id != import_id[0]: id_map.setdefault(import_id[0], {})[import_id[1]] = novel_g.id else: @@ -254,26 +375,56 @@ def add_sample_from_csv(self: Transcriptome, coverage_csv_file, transcripts_file for chrom in gene_infos: for gid, (g, start, end) in gene_infos[chrom].items(): # only transcripts with coverage - import_id = g['ID'] - transcript_list = [transcript for transcript_id, transcript in transcripts[gid].items() if transcript_id in used_transcripts] + import_id = g["ID"] + transcript_list = [ + transcript + for transcript_id, transcript in transcripts[gid].items() + if transcript_id in used_transcripts + ] # find best matching overlapping ref gene - gene = _add_sample_gene(self, start, end, g, transcript_list, chrom, novel_prefix) + gene = _add_sample_gene( + self, start, end, g, transcript_list, chrom, novel_prefix + ) if import_id != gene.id: id_map[import_id] = gene.id # todo: extend sample_table for sample in samples: - sample_properties[sample].update({'name': sample, 'file': coverage_csv_file, 'nonchimeric_reads': cov_tab[sample_cov_cols[sample]].sum(), 'chimeric_reads': 0}) + sample_properties[sample].update( + { + "name": sample, + "file": coverage_csv_file, + "nonchimeric_reads": cov_tab[sample_cov_cols[sample]].sum(), + "chimeric_reads": 0, + } + ) # self.infos['sample_table'] = self.sample_table.append(sample_properties[sample], ignore_index=True) - self.infos['sample_table'] = pd.concat([self.sample_table, pd.DataFrame([sample_properties[sample]])], ignore_index=True) + 
self.infos["sample_table"] = pd.concat( + [self.sample_table, pd.DataFrame([sample_properties[sample]])], + ignore_index=True, + ) self.make_index() return id_map -def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file=None, fuzzy_junction=5, add_chromosomes=True, min_mapqual=0, - min_align_fraction=.75, chimeric_mincov=2, min_exonic_ref_coverage=.25, use_satag=False, save_readnames=False, progress_bar=True, - strictness=math.inf, **kwargs): - '''Imports expressed transcripts from bam and adds it to the 'Transcriptome' object. +def add_sample_from_bam( + self: Transcriptome, + fn, + sample_name=None, + barcode_file=None, + fuzzy_junction=5, + add_chromosomes=True, + min_mapqual=0, + min_align_fraction=0.75, + chimeric_mincov=2, + min_exonic_ref_coverage=0.25, + use_satag=False, + save_readnames=False, + progress_bar=True, + strictness=math.inf, + **kwargs, +): + """Imports expressed transcripts from bam and adds it to the 'Transcriptome' object. :param fn: The bam filename of the new sample :param sample_name: Name of the new sample. If specified, all reads are assumed to belong to this sample. @@ -295,28 +446,42 @@ def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file= :param save_readnames: Save a list of the readnames, that contributed to the transcript. :param progress_bar: Show the progress. :param strictness: Number of bp that two transcripts are allowed to differ for transcription start and end sites to be still considered one transcript. - :param kwargs: Additional keyword arguments are added to the sample table.''' + :param kwargs: Additional keyword arguments are added to the sample table.""" # todo: one alignment may contain several samples - this is not supported at the moment if barcode_file is None: - assert sample_name is not None, 'Neither sample_name nor barcode_file was specified.' - assert sample_name not in self.samples, 'sample %s is already in the data set.' 
% sample_name - logger.info('adding sample %s from file %s', sample_name, fn) + assert ( + sample_name is not None + ), "Neither sample_name nor barcode_file was specified." + assert sample_name not in self.samples, ( + "sample %s is already in the data set." % sample_name + ) + logger.info("adding sample %s from file %s", sample_name, fn) barcodes = {} else: # read the barcode file - barcodes = pd.read_csv(barcode_file, sep='\t', names=['bc', 'name'], index_col='bc')['name'] + barcodes = pd.read_csv( + barcode_file, sep="\t", names=["bc", "name"], index_col="bc" + )["name"] if sample_name is not None: - barcodes = barcodes.apply(lambda x: '{}_{}'.format(sample_name, x)) + barcodes = barcodes.apply(lambda x: "{}_{}".format(sample_name, x)) barcodes = barcodes.to_dict() - assert all(sample not in self.samples for sample in barcodes), \ - 'samples %s are already in the data set.' % ', '.join(sample for sample in barcodes if sample in self.samples) - logger.info('adding %s transcriptomes in %s groups as specified in %s from file %s', - len(set(barcodes.keys())), len(set(barcodes.values())), barcode_file, fn) + assert all( + sample not in self.samples for sample in barcodes + ), "samples %s are already in the data set." 
% ", ".join( + sample for sample in barcodes if sample in self.samples + ) + logger.info( + "adding %s transcriptomes in %s groups as specified in %s from file %s", + len(set(barcodes.keys())), + len(set(barcodes.values())), + barcode_file, + fn, + ) # add reverse complement barcodes.update({rc(k): v for k, v in barcodes.items()}) - kwargs['file'] = fn + kwargs["file"] = fn skip_bc = 0 partial_count = 0 # genome_fh=FastaFile(genome_fn) if genome_fn is not None else None @@ -335,20 +500,31 @@ def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file= total_nc_reads_chr = {} chimeric = dict() - with tqdm(total=total_alignments, unit='reads', unit_scale=True, disable=not progress_bar) as pbar: - - for chrom in chromosomes: # todo: potential issue here - secondary/chimeric alignments to non listed chromosomes are ignored + with tqdm( + total=total_alignments, + unit="reads", + unit_scale=True, + disable=not progress_bar, + ) as pbar: + + for ( + chrom + ) in ( + chromosomes + ): # todo: potential issue here - secondary/chimeric alignments to non listed chromosomes are ignored total_nc_reads_chr[chrom] = dict() pbar.set_postfix(chr=chrom) # transcripts=IntervalTree() # novel=IntervalTree() chr_len = align.get_reference_length(chrom) - transcript_intervals: IntervalArray[Interval] = IntervalArray(chr_len) # intervaltree was pretty slow for this context + transcript_intervals: IntervalArray[Interval] = IntervalArray( + chr_len + ) # intervaltree was pretty slow for this context novel = IntervalArray(chr_len) n_reads = 0 for read in align.fetch(chrom): n_reads += 1 - pbar.update(.5) + pbar.update(0.5) # unmapped if read.flag & 0x4: unmapped += 1 @@ -364,74 +540,123 @@ def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file= continue tags = dict(read.get_tags()) if barcodes: - if 'XC' not in tags or tags['XC'] not in barcodes: + if "XC" not in tags or tags["XC"] not in barcodes: skip_bc += 1 continue - s_name = sample_name if 
not barcodes else barcodes[tags['XC']] - strand = '-' if read.is_reverse else '+' + s_name = sample_name if not barcodes else barcodes[tags["XC"]] + strand = "-" if read.is_reverse else "+" exons = junctions_from_cigar(read.cigartuples, read.reference_start) transcript_range = (exons[0][0], exons[-1][1]) if transcript_range[0] < 0 or transcript_range[1] > chr_len: - logger.error('Alignment outside chromosome range: transcript at %s for chromosome %s of length %s', transcript_range, chrom, chr_len) + logger.error( + "Alignment outside chromosome range: transcript at %s for chromosome %s of length %s", + transcript_range, + chrom, + chr_len, + ) continue - if 'is' in tags: + if "is" in tags: # number of actual reads supporting this transcript - cov = tags['is'] + cov = tags["is"] else: cov = 1 # part of a chimeric alignment - if 'SA' in tags or read.flag & 0x800: + if "SA" in tags or read.flag & 0x800: # store it if it's part of a chimeric alignment and meets the minimum coverage threshold # otherwise ignore chimeric read if chimeric_mincov > 0: chimeric.setdefault(read.query_name, [{s_name: cov}, []]) - assert chimeric[read.query_name][0][s_name] == cov, \ - 'error in bam: parts of chimeric alignment for read %s has different coverage information %s != %s' % ( - read.query_name, chimeric[read.query_name][0], cov) - chimeric[read.query_name][1].append([chrom, strand, exons, aligned_part(read.cigartuples, read.is_reverse), None]) - if use_satag and 'SA' in tags: - for snd_align in (sa.split(',') for sa in tags['SA'].split(';') if sa): + assert chimeric[read.query_name][0][s_name] == cov, ( + "error in bam: parts of chimeric alignment for read %s has different coverage information %s != %s" + % (read.query_name, chimeric[read.query_name][0], cov) + ) + chimeric[read.query_name][1].append( + [ + chrom, + strand, + exons, + aligned_part(read.cigartuples, read.is_reverse), + None, + ] + ) + if use_satag and "SA" in tags: + for snd_align in ( + sa.split(",") for sa in 
tags["SA"].split(";") if sa + ): snd_cigartuples = cigar_string2tuples(snd_align[3]) - snd_exons = junctions_from_cigar(snd_cigartuples, int(snd_align[1])) + snd_exons = junctions_from_cigar( + snd_cigartuples, int(snd_align[1]) + ) chimeric[read.query_name][1].append( - [snd_align[0], snd_align[2], snd_exons, aligned_part(snd_cigartuples, snd_align[2] == '-'), None]) + [ + snd_align[0], + snd_align[2], + snd_exons, + aligned_part( + snd_cigartuples, snd_align[2] == "-" + ), + None, + ] + ) # logging.debug(chimeric[read.query_name]) continue # skipping low-quality alignments try: # if edit distance becomes large relative to read length, skip the alignment - if min_align_fraction > 0 and (1 - tags['NM'] / read.query_length) < min_align_fraction: + if ( + min_align_fraction > 0 + and (1 - tags["NM"] / read.query_length) + < min_align_fraction + ): partial_count += 1 continue except KeyError: - logging.warning('min_align_fraction set > 0 (%s), but reads found without "NM" tag. Setting min_align_fraction to 0', - min_align_fraction) + logging.warning( + 'min_align_fraction set > 0 (%s), but reads found without "NM" tag. Setting min_align_fraction to 0', + min_align_fraction, + ) min_align_fraction = 0 total_nc_reads_chr[chrom].setdefault(s_name, 0) total_nc_reads_chr[chrom][s_name] += cov # did we see this transcript already? 
- for transcript_interval in transcript_intervals.overlap(*transcript_range): - if transcript_interval.data['strand'] != strand: + for transcript_interval in transcript_intervals.overlap( + *transcript_range + ): + if transcript_interval.data["strand"] != strand: continue - if splice_identical(exons, transcript_interval.data['exons'], strictness=strictness): + if splice_identical( + exons, + transcript_interval.data["exons"], + strictness=strictness, + ): transcript = transcript_interval.data - transcript.setdefault('range', {}).setdefault(transcript_range, 0) - transcript['range'][transcript_range] += cov + transcript.setdefault("range", {}).setdefault( + transcript_range, 0 + ) + transcript["range"][transcript_range] += cov if save_readnames: - transcript['reads'].setdefault(s_name, []).append(read.query_name) + transcript["reads"].setdefault(s_name, []).append( + read.query_name + ) break else: - transcript = {'exons': exons, 'range': {transcript_range: cov}, 'strand': strand} + transcript = { + "exons": exons, + "range": {transcript_range: cov}, + "strand": strand, + } if barcodes: - transcript['bc_group'] = barcodes[tags['XC']] + transcript["bc_group"] = barcodes[tags["XC"]] if save_readnames: - transcript['reads'] = {s_name: [read.query_name]} - transcript_intervals.add(Interval(*transcript_range, transcript)) + transcript["reads"] = {s_name: [read.query_name]} + transcript_intervals.add( + Interval(*transcript_range, transcript) + ) # if genome_fh is not None: # mutations=get_mutations(read.cigartuples, read.query_sequence, genome_fh, chrom,read.reference_start,read.query_qualities) # for pos,ref,alt,qual in mutations: @@ -442,19 +667,28 @@ def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file= if 4 in read.cigartuples: # clipping clip = get_clipping(read.cigartuples, read.reference_start) - transcript.setdefault('clipping', {}).setdefault(s_name, {}).setdefault(clip, 0) - transcript['clipping'][s_name][clip] += cov + 
transcript.setdefault("clipping", {}).setdefault( + s_name, {} + ).setdefault(clip, 0) + transcript["clipping"][s_name][clip] += cov for transcript_interval in transcript_intervals: transcript = transcript_interval.data - transcript_ranges = transcript.pop('range') + transcript_ranges = transcript.pop("range") _set_ends_of_transcript(transcript, transcript_ranges, s_name) - gene = _add_sample_transcript(self, transcript, chrom, fuzzy_junction, min_exonic_ref_coverage, strictness=strictness) + gene = _add_sample_transcript( + self, + transcript, + chrom, + fuzzy_junction, + min_exonic_ref_coverage, + strictness=strictness, + ) if gene is None: novel.add(transcript_interval) else: - _ = transcript.pop('bc_group', None) + _ = transcript.pop("bc_group", None) # update the total number of reads processed n_reads -= cov @@ -470,15 +704,26 @@ def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file= sample_nc_reads[sample] = sample_nc_reads.get(sample, 0) + nc_reads if partial_count: - logger.info('skipped %s reads aligned fraction of less than %s.', partial_count, min_align_fraction) + logger.info( + "skipped %s reads aligned fraction of less than %s.", + partial_count, + min_align_fraction, + ) if skip_bc: - logger.warning('skipped %s reads with barcodes not found in the provided list.', skip_bc) + logger.warning( + "skipped %s reads with barcodes not found in the provided list.", skip_bc + ) if n_secondary > 0: - logger.info('skipped %s secondary alignments (0x100), alignment that failed quality check (0x200) or PCR duplicates (0x400)', n_secondary) + logger.info( + "skipped %s secondary alignments (0x100), alignment that failed quality check (0x200) or PCR duplicates (0x400)", + n_secondary, + ) if unmapped > 0: - logger.info('ignored %s reads marked as unaligned', unmapped) + logger.info("ignored %s reads marked as unaligned", unmapped) if n_lowqual > 0: - logger.info('ignored %s reads with mapping quality < %s', n_lowqual, min_mapqual) + 
logger.info( + "ignored %s reads with mapping quality < %s", n_lowqual, min_mapqual + ) # merge chimeric reads and assign gene names n_chimeric = dict() @@ -493,60 +738,113 @@ def add_sample_from_bam(self: Transcriptome, fn, sample_name=None, barcode_file= for sample, nc_cov in nc_cov_dict.items(): n_long_intron[sample] = n_long_intron.get(sample, 0) + nc_cov sample_nc_reads[sample] = sample_nc_reads.get(sample, 0) + nc_cov - chim_ignored = sum(len(chim) for chim in chimeric.values()) - sum(n_chimeric.values()) + chim_ignored = sum(len(chim) for chim in chimeric.values()) - sum( + n_chimeric.values() + ) if chim_ignored > 0: - logger.info('ignoring %s chimeric alignments with less than %s reads', chim_ignored, chimeric_mincov) - chained_msg = '' if not sum(n_long_intron.values()) else f' (including {sum(n_long_intron.values())} chained chimeric alignments)' - chimeric_msg = '' if sum(n_chimeric.values()) == 0 else f' and {sum(n_chimeric.values())} chimeric reads with coverage of at least {chimeric_mincov}' - logger.info('imported %s nonchimeric reads%s%s.', sum(sample_nc_reads.values()), chained_msg, chimeric_msg) + logger.info( + "ignoring %s chimeric alignments with less than %s reads", + chim_ignored, + chimeric_mincov, + ) + chained_msg = ( + "" + if not sum(n_long_intron.values()) + else f" (including {sum(n_long_intron.values())} chained chimeric alignments)" + ) + chimeric_msg = ( + "" + if sum(n_chimeric.values()) == 0 + else f" and {sum(n_chimeric.values())} chimeric reads with coverage of at least {chimeric_mincov}" + ) + logger.info( + "imported %s nonchimeric reads%s%s.", + sum(sample_nc_reads.values()), + chained_msg, + chimeric_msg, + ) for readname, (cov, (chrom, strand, exons, _, _), introns) in non_chimeric.items(): novel = dict() try: - tss, pas = (exons[0][0], exons[-1][1]) if strand == '+' else (exons[-1][1], exons[0][0]) + tss, pas = ( + (exons[0][0], exons[-1][1]) + if strand == "+" + else (exons[-1][1], exons[0][0]) + ) transcript = { - 
'exons': exons, - 'coverage': cov, - 'TSS': {sample: {tss: c} for sample, c in cov.items()}, - 'PAS': {sample: {pas: c} for sample, c in cov.items()}, - 'strand': strand, - 'chr': chrom, - 'long_intron_chimeric': introns + "exons": exons, + "coverage": cov, + "TSS": {sample: {tss: c} for sample, c in cov.items()}, + "PAS": {sample: {pas: c} for sample, c in cov.items()}, + "strand": strand, + "chr": chrom, + "long_intron_chimeric": introns, } if save_readnames: - transcript['reads'] = {s_name: [readname]} + transcript["reads"] = {s_name: [readname]} except BaseException: - logger.error('\n\n-->%s\n\n', (exons[0][0], exons[-1][1]) if strand == "+" else (exons[-1][1], exons[0][0])) + logger.error( + "\n\n-->%s\n\n", + ( + (exons[0][0], exons[-1][1]) + if strand == "+" + else (exons[-1][1], exons[0][0]) + ), + ) raise - gene = _add_sample_transcript(self, transcript, chrom, fuzzy_junction, min_exonic_ref_coverage, strictness=strictness) + gene = _add_sample_transcript( + self, + transcript, + chrom, + fuzzy_junction, + min_exonic_ref_coverage, + strictness=strictness, + ) if gene is None: novel.setdefault(chrom, []).append(transcript) for chrom in novel: - _ = _add_novel_genes(self, IntervalTree(Interval(transcript['exons'][0][0], transcript['exons'][-1][1], transcript) for transcript in novel[chrom]), chrom) + _ = _add_novel_genes( + self, + IntervalTree( + Interval( + transcript["exons"][0][0], + transcript["exons"][-1][1], + transcript, + ) + for transcript in novel[chrom] + ), + chrom, + ) # self.infos.setdefault('chimeric',{})[s_name]=chimeric # save all chimeric reads (for debugging) for gene in self: - if 'coverage' in gene.data and gene.data['coverage'] is not None: # still valid splice graphs no new transcripts - add a row of zeros to coveage + if ( + "coverage" in gene.data and gene.data["coverage"] is not None + ): # still valid splice graphs no new transcripts - add a row of zeros to coveage gene._set_coverage() for s_name in sample_nc_reads: - 
kwargs['chimeric_reads'] = n_chimeric.get(s_name, 0) - kwargs['nonchimeric_reads'] = sample_nc_reads.get(s_name, 0) - kwargs['name'] = s_name + kwargs["chimeric_reads"] = n_chimeric.get(s_name, 0) + kwargs["nonchimeric_reads"] = sample_nc_reads.get(s_name, 0) + kwargs["name"] = s_name # self.infos['sample_table'] = self.sample_table.append(kwargs, ignore_index=True) - self.infos['sample_table'] = pd.concat([self.sample_table, pd.DataFrame([kwargs])], ignore_index=True) + self.infos["sample_table"] = pd.concat( + [self.sample_table, pd.DataFrame([kwargs])], ignore_index=True + ) self.make_index() return total_nc_reads_chr -def _add_chimeric(transcriptome: Transcriptome, new_chimeric, min_cov, min_exonic_ref_coverage): - ''' add new chimeric transcripts to transcriptome, if covered by > min_cov reads - ''' +def _add_chimeric( + transcriptome: Transcriptome, new_chimeric, min_cov, min_exonic_ref_coverage +): + """add new chimeric transcripts to transcriptome, if covered by > min_cov reads""" total = {} for new_bp, new_chim_dict in new_chimeric.items(): n_reads = {} for cov_all, _ in new_chim_dict.values(): for sample, cov in cov_all.items(): - n_reads[sample] = n_reads.get(sample, 0)+cov + n_reads[sample] = n_reads.get(sample, 0) + cov n_reads = {sample: cov for sample, cov in n_reads.items() if cov >= min_cov} if not n_reads: continue @@ -557,52 +855,81 @@ def _add_chimeric(transcriptome: Transcriptome, new_chimeric, min_cov, min_exoni # todo: discard invalid (large overlaps, large gaps) # find equivalent chimeric reads for found in transcriptome.chimeric.setdefault(new_bp, []): - if all(splice_identical(ch1[2], ch2[2]) for ch1, ch2 in zip(new_chim[1], found[1])): + if all( + splice_identical(ch1[2], ch2[2]) + for ch1, ch2 in zip(new_chim[1], found[1]) + ): # add coverage for sample in n_reads: found[0][sample] = found[0].get(sample, 0) + n_reads[sample] # adjust start using the maximum range # this part deviates from normal transcripts, where the median tss/pas is 
used - if found[1][0][1] == '+': # strand of first part - found[1][0][2][0][0] = min(found[1][0][2][0][0], new_chim[1][0][2][0][0]) + if found[1][0][1] == "+": # strand of first part + found[1][0][2][0][0] = min( + found[1][0][2][0][0], new_chim[1][0][2][0][0] + ) else: - found[1][0][2][0][1] = max(found[1][0][2][0][1], new_chim[1][0][2][0][1]) + found[1][0][2][0][1] = max( + found[1][0][2][0][1], new_chim[1][0][2][0][1] + ) # adjust end - if found[1][-1][1] == '+': # strand of last part - found[1][-1][2][-1][1] = max(found[1][-1][2][-1][1], new_chim[1][-1][2][-1][1]) + if found[1][-1][1] == "+": # strand of last part + found[1][-1][2][-1][1] = max( + found[1][-1][2][-1][1], new_chim[1][-1][2][-1][1] + ) else: - found[1][-1][2][-1][0] = min(found[1][-1][2][-1][0], new_chim[1][-1][2][-1][0]) + found[1][-1][2][-1][0] = min( + found[1][-1][2][-1][0], new_chim[1][-1][2][-1][0] + ) break else: # not seen transcriptome.chimeric[new_bp].append(new_chim) for part in new_chim[1]: if part[0] in transcriptome.data: - genes_overlap = [gene for gene in transcriptome.data[part[0]][part[2][0][0]: part[2][-1][1]] if gene.strand == part[1]] - gene, _, _ = _find_matching_gene(genes_overlap, part[2], min_exonic_ref_coverage) # take the best - ignore other hits here + genes_overlap = [ + gene + for gene in transcriptome.data[part[0]][ + part[2][0][0] : part[2][-1][1] + ] + if gene.strand == part[1] + ] + gene, _, _ = _find_matching_gene( + genes_overlap, part[2], min_exonic_ref_coverage + ) # take the best - ignore other hits here if gene is not None: part[4] = gene.name - gene.data.setdefault('chimeric', {})[new_bp] = transcriptome.chimeric[new_bp] + gene.data.setdefault("chimeric", {})[new_bp] = ( + transcriptome.chimeric[new_bp] + ) return total def _breakpoints(chimeric): - ''' gets chimeric aligment as a list and returns list of breakpoints. 
- each breakpoint is a tuple of (chr1, strand1, pos1, chr2, strand2, pos2) - ''' - return tuple((a[0], a[1], a[2][-1][1] if a[1] == '+' else a[2][0][0], - b[0], b[1], b[2][0][0] if b[1] == '+' else b[2][-1][1]) - for a, b in pairwise(chimeric)) + """gets chimeric aligment as a list and returns list of breakpoints. + each breakpoint is a tuple of (chr1, strand1, pos1, chr2, strand2, pos2) + """ + return tuple( + ( + a[0], + a[1], + a[2][-1][1] if a[1] == "+" else a[2][0][0], + b[0], + b[1], + b[2][0][0] if b[1] == "+" else b[2][-1][1], + ) + for a, b in pairwise(chimeric) + ) def _check_chimeric(chimeric): - ''' prepare the chimeric reads: - 1) sort parts according to read order - 2) compute breakpoints - 3) check if the chimeric read is actually a long introns - return list as nonchimeric - 4) sort into dict by breakpoint - return dict as chimeric + """prepare the chimeric reads: + 1) sort parts according to read order + 2) compute breakpoints + 3) check if the chimeric read is actually a long introns - return list as nonchimeric + 4) sort into dict by breakpoint - return dict as chimeric - chimeric[0] is the coverage dict - chimeric[1] is a list of tuples: chrom,strand,exons,[aligned start, end] ''' + chimeric[0] is the coverage dict + chimeric[1] is a list of tuples: chrom,strand,exons,[aligned start, end]""" chimeric_dict = {} non_chimeric = {} @@ -618,10 +945,13 @@ def _check_chimeric(chimeric): # 2) compute breakpoints bpts = _breakpoints(new_chim[1]) # compute breakpoints # 3) check if long intron alignment splits - merge = [i for i, bp in enumerate(bpts) if - bp[0] == bp[3] and # same chr - bp[1] == bp[4] and # same strand, - 0 < (bp[5] - bp[2] if bp[1] == '+' else bp[2] - bp[5]) < 1e6] # max 1mb gap -> long intron + merge = [ + i + for i, bp in enumerate(bpts) + if bp[0] == bp[3] # same chr + and bp[1] == bp[4] # same strand, + and 0 < (bp[5] - bp[2] if bp[1] == "+" else bp[2] - bp[5]) < 1e6 + ] # max 1mb gap -> long intron # todo: also check that the 
aligned parts have not big gap or overlap if merge: # new_chim[1] @@ -632,7 +962,7 @@ def _check_chimeric(chimeric): merged_introns.append(intron) for i in merge: # part i is merged into part i+1 - if new_chim[1][i][1] == '+': # merge into next part + if new_chim[1][i][1] == "+": # merge into next part new_chim[1][i + 1][2] = new_chim[1][i][2] + new_chim[1][i + 1][2] new_chim[1][i + 1][3] = new_chim[1][i][3] + new_chim[1][i + 1][3] else: @@ -647,14 +977,21 @@ def _check_chimeric(chimeric): else: assert len(new_chim[1]) == 1 # coverage, part(chrom,strand,exons,[aligned start, end]), and "long introns" - non_chimeric[readname] = [new_chim[0], new_chim[1][0], tuple(merged_introns)] + non_chimeric[readname] = [ + new_chim[0], + new_chim[1][0], + tuple(merged_introns), + ] if skip: - logger.warning('ignored %s chimeric alignments with only one part aligned to specified chromosomes.', skip) + logger.warning( + "ignored %s chimeric alignments with only one part aligned to specified chromosomes.", + skip, + ) return chimeric_dict, non_chimeric def _set_ends_of_transcript(transcript: Transcript, transcript_ranges, sample_name): - start, end = float('inf'), 0 + start, end = float("inf"), 0 starts, ends = {}, {} for range, cov in transcript_ranges.items(): if range[0] < start: @@ -665,128 +1002,262 @@ def _set_ends_of_transcript(transcript: Transcript, transcript_ranges, sample_na ends[range[1]] = ends.get(range[1], 0) + cov # This will be changed again, when assigning to a gene - transcript['exons'][0][0] = start - transcript['exons'][-1][1] = end + transcript["exons"][0][0] = start + transcript["exons"][-1][1] = end cov = sum(transcript_ranges.values()) - s_name = transcript.get('bc_group', sample_name) - transcript['coverage'] = {s_name: cov} - transcript['TSS'] = {s_name: starts if transcript['strand'] == '+' else ends} - transcript['PAS'] = {s_name: ends if transcript['strand'] == '+' else starts} - - -def _add_sample_gene(transcriptome: Transcriptome, gene_start, 
gene_end, gene_infos, transcript_list, chrom, novel_prefix): - '''add new gene to existing gene in chrom - return gene on success and None if no Gene was found. + s_name = transcript.get("bc_group", sample_name) + transcript["coverage"] = {s_name: cov} + transcript["TSS"] = {s_name: starts if transcript["strand"] == "+" else ends} + transcript["PAS"] = {s_name: ends if transcript["strand"] == "+" else starts} + + +def _add_sample_gene( + transcriptome: Transcriptome, + gene_start, + gene_end, + gene_infos, + transcript_list, + chrom, + novel_prefix, +): + """add new gene to existing gene in chrom - return gene on success and None if no Gene was found. For matching transcripts in gene, transcripts are merged. Coverage, transcript TSS/PAS need to be reset. - Otherwise, a new transcripts are added. In this case, splice graph and coverage have to be reset.''' + Otherwise, a new transcripts are added. In this case, splice graph and coverage have to be reset. + """ if chrom not in transcriptome.data: for transcript in transcript_list: - transcript['annotation'] = (4, {'intergenic': []}) - return transcriptome._add_novel_gene(chrom, gene_start, gene_end, gene_infos['strand'], {'transcripts': transcript_list}, novel_prefix) + transcript["annotation"] = (4, {"intergenic": []}) + return transcriptome._add_novel_gene( + chrom, + gene_start, + gene_end, + gene_infos["strand"], + {"transcripts": transcript_list}, + novel_prefix, + ) # try matching the gene by id try: - best_gene = transcriptome[gene_infos['ID']] - if not best_gene.is_annotated or not has_overlap((best_gene.start, best_gene.end), (gene_start, gene_end)): + best_gene = transcriptome[gene_infos["ID"]] + if not best_gene.is_annotated or not has_overlap( + (best_gene.start, best_gene.end), (gene_start, gene_end) + ): best_gene = None - elsefound = [] # todo: to annotate as fusion, we would need to check for uncovered splice sites, and wether other genes cover them + elsefound = ( + [] + ) # todo: to annotate as 
fusion, we would need to check for uncovered splice sites, and wether other genes cover them except KeyError: best_gene = None if best_gene is None: # find best matching reference gene from t - genes_overlap_strand = [gene for gene in transcriptome.data[chrom][gene_start: gene_end] if gene.strand == gene_infos['strand']] + genes_overlap_strand = [ + gene + for gene in transcriptome.data[chrom][gene_start:gene_end] + if gene.strand == gene_infos["strand"] + ] if not genes_overlap_strand: covered_splice = 0 else: - splice_junctions = sorted(list({(exon1[1], exon2[0]) for transcript in transcript_list for exon1, exon2 in pairwise(transcript['exons'])})) - splice_sites = np.array([gene.ref_segment_graph.find_splice_sites(splice_junctions) if gene.is_annotated - else _find_splice_sites(splice_junctions, gene.transcripts) for gene in genes_overlap_strand]) + splice_junctions = sorted( + list( + { + (exon1[1], exon2[0]) + for transcript in transcript_list + for exon1, exon2 in pairwise(transcript["exons"]) + } + ) + ) + splice_sites = np.array( + [ + ( + gene.ref_segment_graph.find_splice_sites(splice_junctions) + if gene.is_annotated + else _find_splice_sites(splice_junctions, gene.transcripts) + ) + for gene in genes_overlap_strand + ] + ) sum_ol = splice_sites.sum(1) covered_splice = np.max(sum_ol) if covered_splice > 0: best_idx = np.flatnonzero(sum_ol == covered_splice) best_idx = best_idx[0] not_in_best = np.where(~splice_sites[best_idx])[0] - additional = splice_sites[:, not_in_best] # additional= sites not covered by top gene - elsefound = [(gene.name, not_in_best[a]) for gene, a in zip(genes_overlap_strand, additional) if a.sum() > 0] # genes that cover additional splice sites + additional = splice_sites[ + :, not_in_best + ] # additional= sites not covered by top gene + elsefound = [ + (gene.name, not_in_best[a]) + for gene, a in zip(genes_overlap_strand, additional) + if a.sum() > 0 + ] # genes that cover additional splice sites # notfound = 
(splice_sites.sum(0) == 0).nonzero()[0].tolist() # not covered splice sites # transcript['novel_splice_sites'] = not_found # cannot be done here, as gene is handled at once. TODO: maybe later? best_gene = genes_overlap_strand[best_idx] else: - genes_overlap_anti = [gene for gene in transcriptome.data[chrom][gene_start: gene_end] if gene.strand != gene_infos['strand']] + genes_overlap_anti = [ + gene + for gene in transcriptome.data[chrom][gene_start:gene_end] + if gene.strand != gene_infos["strand"] + ] for transcript in transcript_list: - transcript['annotation'] = (4, _get_novel_type(transcript['exons'], genes_overlap_anti, genes_overlap_strand)) - return transcriptome._add_novel_gene(chrom, gene_start, gene_end, gene_infos['strand'], {'transcripts': transcript_list}, novel_prefix) + transcript["annotation"] = ( + 4, + _get_novel_type( + transcript["exons"], genes_overlap_anti, genes_overlap_strand + ), + ) + return transcriptome._add_novel_gene( + chrom, + gene_start, + gene_end, + gene_infos["strand"], + {"transcripts": transcript_list}, + novel_prefix, + ) for transcript in transcript_list: - for tr2 in best_gene.transcripts: # check if correction made it identical to existing - if splice_identical(tr2['exons'], transcript['exons']): + for ( + tr2 + ) in best_gene.transcripts: # check if correction made it identical to existing + if splice_identical(tr2["exons"], transcript["exons"]): _combine_transcripts(tr2, transcript) # potentially loosing information break else: - if best_gene.is_annotated and has_overlap((transcript['exons'][0][0], transcript['exons'][-1][1]), (best_gene.start, best_gene.end)): + if best_gene.is_annotated and has_overlap( + (transcript["exons"][0][0], transcript["exons"][-1][1]), + (best_gene.start, best_gene.end), + ): # potentially problematic: elsefound [(genename, idx),...] 
idx does not refere to transcript splice site try: - transcript['annotation'] = best_gene.ref_segment_graph.get_alternative_splicing(transcript['exons'], elsefound) + transcript["annotation"] = ( + best_gene.ref_segment_graph.get_alternative_splicing( + transcript["exons"], elsefound + ) + ) except Exception: - logger.error('issue categorizing transcript %s with respect to %s', transcript['exons'], str(best_gene)) + logger.error( + "issue categorizing transcript %s with respect to %s", + transcript["exons"], + str(best_gene), + ) raise else: - genes_overlap_strand = [gene for gene in transcriptome.data[chrom][gene_start: gene_end] if gene.strand == gene_infos['strand'] and gene.is_annotated] - genes_overlap_anti = [gene for gene in transcriptome.data[chrom][gene_start: gene_end] if gene.strand != gene_infos['strand'] and gene.is_annotated] - transcript['annotation'] = (4, _get_novel_type(transcript['exons'], genes_overlap_anti, genes_overlap_strand)) # actually may overlap other genes... - best_gene.data.setdefault('transcripts', []).append(transcript) + genes_overlap_strand = [ + gene + for gene in transcriptome.data[chrom][gene_start:gene_end] + if gene.strand == gene_infos["strand"] and gene.is_annotated + ] + genes_overlap_anti = [ + gene + for gene in transcriptome.data[chrom][gene_start:gene_end] + if gene.strand != gene_infos["strand"] and gene.is_annotated + ] + transcript["annotation"] = ( + 4, + _get_novel_type( + transcript["exons"], genes_overlap_anti, genes_overlap_strand + ), + ) # actually may overlap other genes... + best_gene.data.setdefault("transcripts", []).append(transcript) return best_gene -def _add_sample_transcript(transcriptome: Transcriptome, transcript: Transcript, chrom: str, fuzzy_junction: int, min_exonic_ref_coverage: float, genes_overlap=None, strictness=math.inf): - '''add transcript to gene in chrom - return gene on success and None if no Gene was found. 
+def _add_sample_transcript( + transcriptome: Transcriptome, + transcript: Transcript, + chrom: str, + fuzzy_junction: int, + min_exonic_ref_coverage: float, + genes_overlap=None, + strictness=math.inf, +): + """add transcript to gene in chrom - return gene on success and None if no Gene was found. If matching transcript is found in gene, transcripts are merged. Coverage, transcript TSS/PAS need to be reset. - Otherwise, a new transcript is added. In this case, splice graph and coverage have to be reset.''' + Otherwise, a new transcript is added. In this case, splice graph and coverage have to be reset. + """ if chrom not in transcriptome.data: - transcript['annotation'] = (4, {'intergenic': []}) + transcript["annotation"] = (4, {"intergenic": []}) return None if genes_overlap is None: # At this point the transcript still uses min and max from all reads for start and end - genes_overlap = transcriptome.data[chrom][transcript['exons'][0][0]: transcript['exons'][-1][1]] - genes_overlap_strand = [gene for gene in genes_overlap if gene.strand == transcript['strand']] + genes_overlap = transcriptome.data[chrom][ + transcript["exons"][0][0] : transcript["exons"][-1][1] + ] + genes_overlap_strand = [ + gene for gene in genes_overlap if gene.strand == transcript["strand"] + ] # check if transcript is already there (e.g. from other sample, or in case of long intron chimeric alignments also same sample): for gene in genes_overlap_strand: for transcript2 in gene.transcripts: - if splice_identical(transcript2['exons'], transcript['exons'], strictness=strictness): + if splice_identical( + transcript2["exons"], transcript["exons"], strictness=strictness + ): _combine_transcripts(transcript2, transcript) return gene # we have a new transcript (not seen in this or other samples) # check if gene is already there (e.g. 
from same or other sample): - gene, additional, not_covered = _find_matching_gene(genes_overlap_strand, transcript['exons'], min_exonic_ref_coverage) + gene, additional, not_covered = _find_matching_gene( + genes_overlap_strand, transcript["exons"], min_exonic_ref_coverage + ) if gene is not None: - if gene.is_annotated: # check for fuzzy junctions (e.g. same small shift at 5' and 3' compared to reference annotation) - shifts = gene.correct_fuzzy_junctions(transcript, fuzzy_junction, modify=True) # this modifies transcript['exons'] + if ( + gene.is_annotated + ): # check for fuzzy junctions (e.g. same small shift at 5' and 3' compared to reference annotation) + shifts = gene.correct_fuzzy_junctions( + transcript, fuzzy_junction, modify=True + ) # this modifies transcript['exons'] if shifts: - transcript.setdefault('fuzzy_junction', []).append(shifts) # keep the info, mainly for testing/statistics - for transcript2 in gene.transcripts: # check if correction made it identical to existing - if splice_identical(transcript2['exons'], transcript['exons'], strictness=strictness): - transcript2.setdefault('fuzzy_junction', []).append(shifts) # keep the info, mainly for testing/statistics + transcript.setdefault("fuzzy_junction", []).append( + shifts + ) # keep the info, mainly for testing/statistics + for ( + transcript2 + ) in ( + gene.transcripts + ): # check if correction made it identical to existing + if splice_identical( + transcript2["exons"], transcript["exons"], strictness=strictness + ): + transcript2.setdefault("fuzzy_junction", []).append( + shifts + ) # keep the info, mainly for testing/statistics _combine_transcripts(transcript2, transcript) return gene - transcript['annotation'] = gene.ref_segment_graph.get_alternative_splicing(transcript['exons'], additional) + transcript["annotation"] = gene.ref_segment_graph.get_alternative_splicing( + transcript["exons"], additional + ) if not_covered: - transcript['novel_splice_sites'] = not_covered # todo: might be 
changed by fuzzy junction + transcript["novel_splice_sites"] = ( + not_covered # todo: might be changed by fuzzy junction + ) # intersects might have changed due to fuzzy junctions # {'sj_i': sj_i, 'base_i':base_i,'category':SPLICE_CATEGORY[altsplice[1]],'subcategory':altsplice[1]} else: # add to existing novel (e.g. not in reference) gene - start, end = min(transcript['exons'][0][0], gene.start), max(transcript['exons'][-1][1], gene.end) - transcript['annotation'] = (4, _get_novel_type(transcript['exons'], genes_overlap, genes_overlap_strand)) - if start < gene.start or end > gene.end: # range of the novel gene might have changed + start, end = min(transcript["exons"][0][0], gene.start), max( + transcript["exons"][-1][1], gene.end + ) + transcript["annotation"] = ( + 4, + _get_novel_type( + transcript["exons"], genes_overlap, genes_overlap_strand + ), + ) + if ( + start < gene.start or end > gene.end + ): # range of the novel gene might have changed new_gene = Gene(start, end, gene.data, transcriptome) - transcriptome.data[chrom].add(new_gene) # todo: potential issue: in this case two genes may have grown together + transcriptome.data[chrom].add( + new_gene + ) # todo: potential issue: in this case two genes may have grown together transcriptome.data[chrom].remove(gene) gene = new_gene # if additional: @@ -796,58 +1267,87 @@ def _add_sample_transcript(transcriptome: Transcriptome, transcript: Transcript, # transcript[what] = {sample_name: transcript[what]} # if 'reads' in transcript: # transcript['reads'] = {sample_name: transcript['reads']} - gene.data.setdefault('transcripts', []).append(transcript) - gene.data['segment_graph'] = None # gets recomputed on next request - gene.data['coverage'] = None + gene.data.setdefault("transcripts", []).append(transcript) + gene.data["segment_graph"] = None # gets recomputed on next request + gene.data["coverage"] = None else: # new novel gene - transcript['annotation'] = (4, _get_novel_type(transcript['exons'], 
genes_overlap, genes_overlap_strand)) + transcript["annotation"] = ( + 4, + _get_novel_type(transcript["exons"], genes_overlap, genes_overlap_strand), + ) return gene def _combine_transcripts(established: Transcript, new_transcript: Transcript): - 'merge new_transcript into splice identical established transcript' + "merge new_transcript into splice identical established transcript" try: - for sample_name in new_transcript['coverage']: - established['coverage'][sample_name] = established['coverage'].get(sample_name, 0) + new_transcript['coverage'][sample_name] - for sample_name in new_transcript.get('reads', {}): - established['reads'].setdefault(sample_name, []).extend(new_transcript['reads'][sample_name]) - for side in 'TSS', 'PAS': + for sample_name in new_transcript["coverage"]: + established["coverage"][sample_name] = ( + established["coverage"].get(sample_name, 0) + + new_transcript["coverage"][sample_name] + ) + for sample_name in new_transcript.get("reads", {}): + established["reads"].setdefault(sample_name, []).extend( + new_transcript["reads"][sample_name] + ) + for side in "TSS", "PAS": for sample_name in new_transcript[side]: for pos, cov in new_transcript[side][sample_name].items(): established[side].setdefault(sample_name, {}).setdefault(pos, 0) established[side][sample_name][pos] += cov # find median tss and pas - starts = [pos for sample in established['TSS'] for pos in established['TSS'][sample].items()] - ends = [pos for sample in established['PAS'] for pos in established['PAS'][sample].items()] - if established['strand'] == '-': + starts = [ + pos + for sample in established["TSS"] + for pos in established["TSS"][sample].items() + ] + ends = [ + pos + for sample in established["PAS"] + for pos in established["PAS"][sample].items() + ] + if established["strand"] == "-": starts, ends = ends, starts - established['exons'][0][0] = get_quantiles(starts, [0.5])[0] - established['exons'][-1][1] = get_quantiles(ends, [0.5])[0] - if 'long_intron_chimeric' 
in new_transcript: - new_introns = set(new_transcript['long_intron_chimeric']) - established_introns = set(established.get('long_intron_chimeric', set())) - established['long_intron_chimeric'] = tuple(new_introns.union(established_introns)) + established["exons"][0][0] = get_quantiles(starts, [0.5])[0] + established["exons"][-1][1] = get_quantiles(ends, [0.5])[0] + if "long_intron_chimeric" in new_transcript: + new_introns = set(new_transcript["long_intron_chimeric"]) + established_introns = set(established.get("long_intron_chimeric", set())) + established["long_intron_chimeric"] = tuple( + new_introns.union(established_introns) + ) except BaseException as e: - logger.error(f'error when merging {new_transcript} into {established}') + logger.error(f"error when merging {new_transcript} into {established}") raise e def _get_novel_type(exons, genes_overlap, genes_overlap_strand): if len(genes_overlap_strand): - exonic_overlap = {gene.id: gene.ref_segment_graph.get_overlap(exons) for gene in genes_overlap_strand if gene.is_annotated} + exonic_overlap = { + gene.id: gene.ref_segment_graph.get_overlap(exons) + for gene in genes_overlap_strand + if gene.is_annotated + } exonic_overlap_genes = [k for k, v in exonic_overlap.items() if v[0] > 0] if len(exonic_overlap_genes) > 0: - return {'genic genomic': exonic_overlap_genes} - return {'intronic': [gene.name for gene in genes_overlap_strand]} + return {"genic genomic": exonic_overlap_genes} + return {"intronic": [gene.name for gene in genes_overlap_strand]} elif len(genes_overlap): - return {'antisense': [gene.name for gene in genes_overlap]} + return {"antisense": [gene.name for gene in genes_overlap]} else: - return {'intergenic': []} + return {"intergenic": []} -def _add_novel_genes(transcriptome: Transcriptome, novel, chrom, spj_iou_th=0, reg_iou_th=.5, gene_prefix='IT_novel_'): +def _add_novel_genes( + transcriptome: Transcriptome, + novel, + chrom, + spj_iou_th=0, + reg_iou_th=0.5, + gene_prefix="IT_novel_", +): 
'"novel" is a tree of transcript intervals (not Gene objects) ,e.g. from one chromosome, that do not overlap any annotated or unanntoated gene' n_novel = transcriptome.novel_genes idx = {id(transcript): i for i, transcript in enumerate(novel)} @@ -855,11 +1355,21 @@ def _add_novel_genes(transcriptome: Transcriptome, novel, chrom, spj_iou_th=0, r merge = list() for i, transcript in enumerate(novel): merge.append({transcript}) - candidates = [candidate for candidate in novel.overlap(transcript.begin, transcript.end) if candidate.data['strand'] == transcript.data['strand'] and idx[id(candidate)] < i] + candidates = [ + candidate + for candidate in novel.overlap(transcript.begin, transcript.end) + if candidate.data["strand"] == transcript.data["strand"] + and idx[id(candidate)] < i + ] for candidate in candidates: if candidate in merge[i]: continue - if is_same_gene(transcript.data['exons'], candidate.data['exons'], spj_iou_th, reg_iou_th): + if is_same_gene( + transcript.data["exons"], + candidate.data["exons"], + spj_iou_th, + reg_iou_th, + ): # add all transcripts of candidate merge[i].update(merge[idx[id(candidate)]]) for candidate in merge[i]: # update all overlapping (add the current to them) @@ -871,10 +1381,10 @@ def _add_novel_genes(transcriptome: Transcriptome, novel, chrom, spj_iou_th=0, r continue seen.add(id(trS)) trL = [transcript.data for transcript in trS] - strand = trL[0]['strand'] - start = min(transcript['exons'][0][0] for transcript in trL) - end = max(transcript['exons'][-1][1] for transcript in trL) - assert start < end, 'novel gene with start >= end' + strand = trL[0]["strand"] + start = min(transcript["exons"][0][0] for transcript in trL) + end = max(transcript["exons"][-1][1] for transcript in trL) + assert start < end, "novel gene with start >= end" # for transcript in transcriptL: # sample_name = transcript.pop('bc_group', sa) # transcript['coverage'] = {sample_name: transcript['coverage']} @@ -882,27 +1392,43 @@ def 
_add_novel_genes(transcriptome: Transcriptome, novel, chrom, spj_iou_th=0, r # transcript['PAS'] = {sample_name: transcript['PAS']} # if 'reads' in transcript: # transcript['reads'] = {sample_name: transcript['reads']} - novel_gene_list.append(transcriptome._add_novel_gene(chrom, start, end, strand, {'transcripts': trL}, gene_prefix)) - logging.debug('merging transcripts of novel gene %s: %s', n_novel, trL) + novel_gene_list.append( + transcriptome._add_novel_gene( + chrom, start, end, strand, {"transcripts": trL}, gene_prefix + ) + ) + logging.debug("merging transcripts of novel gene %s: %s", n_novel, trL) return novel_gene_list -def _find_matching_gene(genes_overlap: list[Gene], exons: list[Tuple[int, int]], min_exon_coverage: float): - '''check the splice site intersect of all overlapping genes and return - 1) the gene with most shared splice sites, - 2) names of genes that cover additional splice sites, and - 3) splice sites that are not covered. - If no splice site is shared (and for mono-exon genes) return the gene with largest exonic overlap - :param genes_ol: list of genes that overlap the transcript - :param exons: exon list of the transcript - :param min_exon_coverage: minimum exonic coverage with genes that do not share splice sites to be considered''' +def _find_matching_gene( + genes_overlap: list[Gene], exons: list[Tuple[int, int]], min_exon_coverage: float +): + """check the splice site intersect of all overlapping genes and return + 1) the gene with most shared splice sites, + 2) names of genes that cover additional splice sites, and + 3) splice sites that are not covered. 
+ If no splice site is shared (and for mono-exon genes) return the gene with largest exonic overlap + :param genes_ol: list of genes that overlap the transcript + :param exons: exon list of the transcript + :param min_exon_coverage: minimum exonic coverage with genes that do not share splice sites to be considered + """ transcript_len = sum(exon[1] - exon[0] for exon in exons) splice_junctions = [(exon1[1], exon2[0]) for exon1, exon2 in pairwise(exons)] if genes_overlap: if len(exons) > 1: nomatch = np.zeros(len(splice_junctions) * 2, dtype=bool) # Check reference transcript of reference genes first - splice_sites = np.array([gene.ref_segment_graph.find_splice_sites(splice_junctions) if gene.is_annotated else nomatch for gene in genes_overlap]) + splice_sites = np.array( + [ + ( + gene.ref_segment_graph.find_splice_sites(splice_junctions) + if gene.is_annotated + else nomatch + ) + for gene in genes_overlap + ] + ) sum_overlap = splice_sites.sum(1) # find index of reference gene that covers the most splice sites # resolved issue with tie here, missing FSM due to overlapping gene with extension of FSM transcript @@ -911,7 +1437,16 @@ def _find_matching_gene(genes_overlap: list[Gene], exons: list[Tuple[int, int]], # none found, consider novel genes # TODO: Consider to replace this with gene.segment_graph.find_splice_sites(splice_junctions) # since the graph is cached, this might be faster - splice_sites = np.array([_find_splice_sites(splice_junctions, gene.transcripts) if not gene.is_annotated else nomatch for gene in genes_overlap]) + splice_sites = np.array( + [ + ( + _find_splice_sites(splice_junctions, gene.transcripts) + if not gene.is_annotated + else nomatch + ) + for gene in genes_overlap + ] + ) sum_overlap = splice_sites.sum(1) covered_splice = np.max(sum_overlap) if covered_splice > 0: @@ -920,10 +1455,26 @@ def _find_matching_gene(genes_overlap: list[Gene], exons: list[Tuple[int, int]], # find the transcript with the highest fraction of matching junctions 
transcript_inter = [] for idx in best_idx: - transcript_list = genes_overlap[idx].ref_transcripts if genes_overlap[idx].is_annotated else genes_overlap[idx].transcripts - transcript_intersects = [(get_intersects(exons, transcript['exons'])) for transcript in transcript_list] - transcript_intersects_fraction = [(junction_intersection / len(transcript_list[idx]['exons']), exon_intersection) - for idx, (junction_intersection, exon_intersection) in enumerate(transcript_intersects)] + transcript_list = ( + genes_overlap[idx].ref_transcripts + if genes_overlap[idx].is_annotated + else genes_overlap[idx].transcripts + ) + transcript_intersects = [ + (get_intersects(exons, transcript["exons"])) + for transcript in transcript_list + ] + transcript_intersects_fraction = [ + ( + junction_intersection + / len(transcript_list[idx]["exons"]), + exon_intersection, + ) + for idx, ( + junction_intersection, + exon_intersection, + ) in enumerate(transcript_intersects) + ] transcript_inter.append(max(transcript_intersects_fraction)) best_idx = best_idx[np.argmax(transcript_inter, axis=0)[0]] else: @@ -932,14 +1483,27 @@ def _find_matching_gene(genes_overlap: list[Gene], exons: list[Tuple[int, int]], # sites not covered by top gene but in exons additional = splice_sites[:, not_in_best] # genes that cover additional splice sites - elsefound = [(gene.name, not_in_best[a]) for gene, a in zip(genes_overlap, additional) if a.sum() > 0] + elsefound = [ + (gene.name, not_in_best[a]) + for gene, a in zip(genes_overlap, additional) + if a.sum() > 0 + ] # not covered splice sites notfound = (splice_sites.sum(0) == 0).nonzero()[0].tolist() return genes_overlap[best_idx], elsefound, notfound # no shared splice sites, return gene with largest overlap # first, check reference here: # 1) if >50% overlap with ref gene -> return best ref gene - overlap = np.array([gene.ref_segment_graph.get_overlap(exons)[0] if gene.is_annotated else 0 for gene in genes_overlap]) + overlap = np.array( + [ + ( + 
gene.ref_segment_graph.get_overlap(exons)[0] + if gene.is_annotated + else 0 + ) + for gene in genes_overlap + ] + ) best_idx = overlap.argmax() if overlap[best_idx] >= min_exon_coverage * transcript_len: return genes_overlap[best_idx], None, list(range((len(exons) - 1) * 2)) @@ -947,26 +1511,68 @@ def _find_matching_gene(genes_overlap: list[Gene], exons: list[Tuple[int, int]], # len(exons)==1 check all ref transcripts for monoexon gene with overlap>=50% (or min_exon_coverage) overlap = [] for gene in genes_overlap: - overlap_gene = [get_overlap(exons[0], transcript['exons'][0]) for transcript in gene.ref_transcripts if len(transcript['exons']) == 1] + overlap_gene = [ + get_overlap(exons[0], transcript["exons"][0]) + for transcript in gene.ref_transcripts + if len(transcript["exons"]) == 1 + ] overlap.append(max(overlap_gene, default=0)) best_idx = np.argmax(overlap) if overlap[best_idx] >= min_exon_coverage * transcript_len: return genes_overlap[best_idx], None, [] # else return best overlapping novel gene if more than minimum overlap fraction - overlap = [(0, []) if not gene.is_annotated else gene.ref_segment_graph.get_overlap(exons) for gene in genes_overlap] - max_overlap_frac = np.array([0 if overlap[0] == 0 else - max(overlap_transcript / min(transcript_len, sum(exon[1] - exon[0] for exon in transcript["exons"])) - for overlap_transcript, transcript in zip(overlap[1], gene.ref_transcripts)) - for gene, overlap in zip(genes_overlap, overlap)]) + overlap = [ + ( + (0, []) + if not gene.is_annotated + else gene.ref_segment_graph.get_overlap(exons) + ) + for gene in genes_overlap + ] + max_overlap_frac = np.array( + [ + ( + 0 + if overlap[0] == 0 + else max( + overlap_transcript + / min( + transcript_len, + sum(exon[1] - exon[0] for exon in transcript["exons"]), + ) + for overlap_transcript, transcript in zip( + overlap[1], gene.ref_transcripts + ) + ) + ) + for gene, overlap in zip(genes_overlap, overlap) + ] + ) best_idx = max_overlap_frac.argmax() if 
max_overlap_frac[best_idx] >= min_exon_coverage: - return genes_overlap[best_idx], None, list(range((len(exons) - 1) * 2)) # none of the junctions are covered + return ( + genes_overlap[best_idx], + None, + list(range((len(exons) - 1) * 2)), + ) # none of the junctions are covered # else return best overlapping novel gene if more than minimum overlap fraction - overlap = [0 if gene.is_annotated else _get_overlap(exons, gene.transcripts) for gene in genes_overlap] - max_overlap_frac = np.array([0 if overlap == 0 else overlap_gene / transcript_len for overlap_gene in overlap]) + overlap = [ + 0 if gene.is_annotated else _get_overlap(exons, gene.transcripts) + for gene in genes_overlap + ] + max_overlap_frac = np.array( + [ + 0 if overlap == 0 else overlap_gene / transcript_len + for overlap_gene in overlap + ] + ) best_idx = max_overlap_frac.argmax() if max_overlap_frac[best_idx] >= min_exon_coverage: - return genes_overlap[best_idx], None, list(range((len(exons) - 1) * 2)) # none of the junctions are covered + return ( + genes_overlap[best_idx], + None, + list(range((len(exons) - 1) * 2)), + ) # none of the junctions are covered # TODO: Issue: order matters here, if more than one novel gene with >50%overlap, join them all?) 
# none of the junctions are covered return None, None, list(range((len(exons) - 1) * 2)) @@ -976,7 +1582,9 @@ def _read_gtf_file(file_name, chromosomes, infer_genes=False, progress_bar=True) exons = dict() # transcript id -> exons transcripts = dict() # gene_id -> transcripts skipped = defaultdict(set) - gene_infos = dict() # 4 tuple: info_dict, gene_start, gene_end, fixed_flag==True if start/end are fixed + gene_infos = ( + dict() + ) # 4 tuple: info_dict, gene_start, gene_end, fixed_flag==True if start/end are fixed cds_start = dict() cds_stop = dict() # with tqdm(total=path.getsize(file_name), unit_scale=True, unit='B', unit_divisor=1024, disable=not progress_bar) as pbar, TabixFile(file_name) as gtf: @@ -984,34 +1592,46 @@ def _read_gtf_file(file_name, chromosomes, infer_genes=False, progress_bar=True) # file_pos = gtf.tell() >> 16 # if pbar.n < file_pos: # pbar.update(file_pos-pbar.n) - openfun = gziplib.open if file_name.endswith('.gz') else open + openfun = gziplib.open if file_name.endswith(".gz") else open with openfun(file_name, "rt") as gtf: for line in gtf: - if line[0] == '#': # ignore header lines + if line[0] == "#": # ignore header lines continue ls = line.split(sep="\t") chr = ls[0] - assert len(ls) == 9, 'unexpected number of fields in gtf line:\n%s' % line + assert len(ls) == 9, "unexpected number of fields in gtf line:\n%s" % line if chromosomes is not None and chr not in chromosomes: - logger.debug('skipping line from chr ' + chr) + logger.debug("skipping line from chr " + chr) continue try: - info = dict([pair.lstrip().split(maxsplit=1) for pair in ls[8].strip().replace('"', '').split(";") if pair]) + info = dict( + [ + pair.lstrip().split(maxsplit=1) + for pair in ls[8].strip().replace('"', "").split(";") + if pair + ] + ) except ValueError: - logger.error('issue with key value pairs from gtf:\n%s', ls[8]) + logger.error("issue with key value pairs from gtf:\n%s", ls[8]) raise # gtf of transcriptome reconstructed by external tools may include 
entries without strand, which can't be mapped to the genome, so skip them - if ls[6] not in ('+', '-'): - logger.warning('skipping line with unknown strand:\n%s', line) + if ls[6] not in ("+", "-"): + logger.warning("skipping line with unknown strand:\n%s", line) # add this entry to skipped # keys are feature types (ls[2], e.g. gene, transcript, exon) and values are sets of feature ids that are searched in ls[-1] - feature_id = [i.split(' ')[-1].strip('"') for i in ls[-1].split(sep=';') if f'{ls[2]}_id' or f'{ls[2]}_number' in i] + feature_id = [ + i.split(" ")[-1].strip('"') + for i in ls[-1].split(sep=";") + if f"{ls[2]}_id" in i or f"{ls[2]}_number" in i + ] if len(feature_id) == 1: skipped[ls[2]].add(feature_id[0]) else: - logger.debug(f'found {"multiple" if len(feature_id) > 1 else "no"} feature ids in line:\n{line}') + logger.debug( + f'found {"multiple" if len(feature_id) > 1 else "no"} feature ids in line:\n{line}' + ) pass continue @@ -1021,50 +1641,98 @@ def _read_gtf_file(file_name, chromosomes, infer_genes=False, progress_bar=True) if ls[2] == "exon": # logger.debug(line) try: - _ = exons.setdefault(info['transcript_id'], list()).append((start, end)) + _ = exons.setdefault(info["transcript_id"], list()).append( + (start, end) + ) except KeyError: # should not happen if GTF is OK - logger.error("gtf format error: exon without transcript_id\n%s", line) + logger.error( + "gtf format error: exon without transcript_id\n%s", line + ) raise - if infer_genes and 'gene_id' in info: - if info['gene_id'] not in gene_infos[chr]: # new gene - info['strand'] = ls[6] - info['chr'] = chr - _set_alias(info, {'ID': ['gene_id']}) - _set_alias(info, {'name': ['gene_name', 'Name']}, required=False) - ref_info = {k: v for k, v in info.items() if k not in Gene.required_infos + ['name']} - info = {k: info[k] for k in Gene.required_infos + ['name'] if k in info} - info['properties'] = ref_info - gene_infos[chr][info['ID']] = (info, start, end) # start/end not fixed yet 
(initially set to exon start end) + if infer_genes and "gene_id" in info: + if info["gene_id"] not in gene_infos[chr]: # new gene + info["strand"] = ls[6] + info["chr"] = chr + _set_alias(info, {"ID": ["gene_id"]}) + _set_alias( + info, {"name": ["gene_name", "Name"]}, required=False + ) + ref_info = { + k: v + for k, v in info.items() + if k not in Gene.required_infos + ["name"] + } + info = { + k: info[k] + for k in Gene.required_infos + ["name"] + if k in info + } + info["properties"] = ref_info + gene_infos[chr][info["ID"]] = ( + info, + start, + end, + ) # start/end not fixed yet (initially set to exon start end) else: - known_info = gene_infos[chr][info['gene_id']] - gene_infos[chr][info['gene_id']] = (known_info[0], min(known_info[1], start), max(known_info[2], end)) - if 'transcript_id' in info and info['transcript_id'] not in transcripts.setdefault(info['gene_id'], {}): + known_info = gene_infos[chr][info["gene_id"]] + gene_infos[chr][info["gene_id"]] = ( + known_info[0], + min(known_info[1], start), + max(known_info[2], end), + ) + if "transcript_id" in info and info[ + "transcript_id" + ] not in transcripts.setdefault(info["gene_id"], {}): # new transcript - tr_info = {k: v for k, v in info.items() if 'transcript' in k and k != 'transcript_id'} - transcripts[info["gene_id"]][info["transcript_id"]] = tr_info - elif ls[2] == 'gene': - if 'gene_id' not in info: - logger.warning("gtf format error: gene without gene_id. Skipping line\n%s", line) + tr_info = { + k: v + for k, v in info.items() + if "transcript" in k and k != "transcript_id" + } + transcripts[info["gene_id"]][ + info["transcript_id"] + ] = tr_info + elif ls[2] == "gene": + if "gene_id" not in info: + logger.warning( + "gtf format error: gene without gene_id. 
Skipping line\n%s", + line, + ) else: # overrule potential entries from exon line - info['strand'] = ls[6] - info['chr'] = chr - _set_alias(info, {'ID': ['gene_id']}) - _set_alias(info, {'name': ['gene_name', 'Name']}, required=False) - ref_info = {k: v for k, v in info.items() if k not in Gene.required_infos + ['name']} - info = {k: info[k] for k in Gene.required_infos + ['name'] if k in info} - info['properties'] = ref_info - gene_infos[chr][info['ID']] = (info, start, end) - elif ls[2] == 'transcript': # overrule potential entries from exon line + info["strand"] = ls[6] + info["chr"] = chr + _set_alias(info, {"ID": ["gene_id"]}) + _set_alias(info, {"name": ["gene_name", "Name"]}, required=False) + ref_info = { + k: v + for k, v in info.items() + if k not in Gene.required_infos + ["name"] + } + info = { + k: info[k] for k in Gene.required_infos + ["name"] if k in info + } + info["properties"] = ref_info + gene_infos[chr][info["ID"]] = (info, start, end) + elif ls[2] == "transcript": # overrule potential entries from exon line try: # keep only transcript related infos (to avoid redundant gene infos) - tr_info = {k: v for k, v in info.items() if 'transcript' in k and k != 'transcript_id'} - _ = transcripts.setdefault(info["gene_id"], dict())[info["transcript_id"]] = tr_info + tr_info = { + k: v + for k, v in info.items() + if "transcript" in k and k != "transcript_id" + } + _ = transcripts.setdefault(info["gene_id"], dict())[ + info["transcript_id"] + ] = tr_info except KeyError: - logger.warning("gtf format errror: transcript must have gene_id and transcript_id, skipping line\n%s", line) - elif ls[2] == 'start_codon' and 'transcript_id' in info: - cds_start[info['transcript_id']] = end if ls[6] == '-' else start - elif ls[2] == 'stop_codon' and 'transcript_id' in info: - cds_stop[info['transcript_id']] = start if ls[6] == '-' else end + logger.warning( + "gtf format errror: transcript must have gene_id and transcript_id, skipping line\n%s", + line, + ) + elif 
ls[2] == "start_codon" and "transcript_id" in info: + cds_start[info["transcript_id"]] = end if ls[6] == "-" else start + elif ls[2] == "stop_codon" and "transcript_id" in info: + cds_stop[info["transcript_id"]] = start if ls[6] == "-" else end else: # skip other feature types. Only keep a record of feature type without further information in skipped # this usually happens to reference annotation, eg: UTR, CDS etc. @@ -1074,7 +1742,7 @@ def _read_gtf_file(file_name, chromosomes, infer_genes=False, progress_bar=True) def _get_tabix_end(tbx_fh): - for line in tbx_fh.fetch(tbx_fh.contigs[-1]): + for _line in tbx_fh.fetch(tbx_fh.contigs[-1]): pass end = tbx_fh.tell() tbx_fh.seek(0) @@ -1089,47 +1757,69 @@ def _read_gff_file(file_name, chromosomes, progress_bar=True): cds_start = dict() cds_stop = dict() # takes quite some time... add a progress bar? - with tqdm(total=path.getsize(file_name), unit_scale=True, unit='B', unit_divisor=1024, disable=not progress_bar) as pbar, TabixFile(file_name) as gff: + with tqdm( + total=path.getsize(file_name), + unit_scale=True, + unit="B", + unit_divisor=1024, + disable=not progress_bar, + ) as pbar, TabixFile(file_name) as gff: chrom_ids = get_gff_chrom_dict(gff, chromosomes) for line in gff.fetch(): - file_pos = gff.tell() >> 16 # the lower 16 bit are the position within the zipped block + file_pos = ( + gff.tell() >> 16 + ) # the lower 16 bit are the position within the zipped block if pbar.n < file_pos: - pbar.update(file_pos-pbar.n) + pbar.update(file_pos - pbar.n) ls = line.split(sep="\t") if ls[0] not in chrom_ids: continue chrom = chrom_ids[ls[0]] if chromosomes is not None and chrom not in chromosomes: - logger.debug('skipping line %s from chr %s', line, chrom) + logger.debug("skipping line %s from chr %s", line, chrom) continue try: - info = dict([pair.split('=', 1) for pair in ls[8].rstrip(';').split(";")]) # some gff lines end with ';' in gencode 36 + info = dict( + [pair.split("=", 1) for pair in 
ls[8].rstrip(";").split(";")] + ) # some gff lines end with ';' in gencode 36 except ValueError: - logger.warning("GFF format error in infos (should be ; separated key=value pairs). Skipping line:\n%s", line) + logger.warning( + "GFF format error in infos (should be ; separated key=value pairs). Skipping line:\n%s", + line, + ) start, end = [int(i) for i in ls[3:5]] start -= 1 # to make 0 based if ls[2] == "exon": try: - gff_id = info['Parent'] + gff_id = info["Parent"] exons.setdefault(gff_id, list()).append((start, end)) except KeyError: # should not happen if GFF is OK - logger.warning("GFF format error: no parent found for exon. Skipping line:\n%s", line) - elif ls[2] == 'gene' or 'ID' in info and info['ID'].startswith('gene'): - info['strand'] = ls[6] - info['chr'] = chrom - _set_alias(info, {'ID': ['gene_id']}) - _set_alias(info, {'name': ['Name', 'gene_name']}, required=False) - ref_info = {k: v for k, v in info.items() if k not in Gene.required_infos + ['name']} - info = {k: info[k] for k in Gene.required_infos + ['name'] if k in info} - info['properties'] = ref_info - genes.setdefault(chrom, {})[info['ID']] = (info, start, end) - elif all([v in info for v in ['Parent', "ID"]]) and (ls[2] == 'transcript' or info['Parent'].startswith('gene')): # those denote transcripts - tr_info = {k: v for k, v in info.items() if k.startswith('transcript_')} - transcripts.setdefault(info["Parent"], {})[info['ID']] = tr_info - elif ls[2] == 'start_codon' and 'Parent' in info: - cds_start[info['Parent']] = end if ls[6] == '-' else start - elif ls[2] == 'stop_codon' and 'Parent' in info: - cds_stop[info['Parent']] = start if ls[6] == '-' else end + logger.warning( + "GFF format error: no parent found for exon. 
Skipping line:\n%s", + line, + ) + elif ls[2] == "gene" or "ID" in info and info["ID"].startswith("gene"): + info["strand"] = ls[6] + info["chr"] = chrom + _set_alias(info, {"ID": ["gene_id"]}) + _set_alias(info, {"name": ["Name", "gene_name"]}, required=False) + ref_info = { + k: v + for k, v in info.items() + if k not in Gene.required_infos + ["name"] + } + info = {k: info[k] for k in Gene.required_infos + ["name"] if k in info} + info["properties"] = ref_info + genes.setdefault(chrom, {})[info["ID"]] = (info, start, end) + elif all([v in info for v in ["Parent", "ID"]]) and ( + ls[2] == "transcript" or info["Parent"].startswith("gene") + ): # those denote transcripts + tr_info = {k: v for k, v in info.items() if k.startswith("transcript_")} + transcripts.setdefault(info["Parent"], {})[info["ID"]] = tr_info + elif ls[2] == "start_codon" and "Parent" in info: + cds_start[info["Parent"]] = end if ls[6] == "-" else start + elif ls[2] == "stop_codon" and "Parent" in info: + cds_stop[info["Parent"]] = start if ls[6] == "-" else end else: # skip other feature types. Only keep a record of feature type without further information in skipped # this usually happens to reference annotation, eg: UTR, CDS etc. @@ -1137,62 +1827,97 @@ def _read_gff_file(file_name, chromosomes, progress_bar=True): return exons, transcripts, genes, cds_start, cds_stop, skipped -def import_ref_transcripts(fn, transcriptome: Transcriptome, file_format, chromosomes=None, gene_categories=None, short_exon_th=25, **kwargs): - '''import transcripts from gff/gtf file (e.g. for a reference) - returns a dict interval trees for the genes''' +def import_ref_transcripts( + fn, + transcriptome: Transcriptome, + file_format, + chromosomes=None, + gene_categories=None, + short_exon_th=25, + **kwargs, +): + """import transcripts from gff/gtf file (e.g. 
for a reference) + returns a dict interval trees for the genes""" if gene_categories is None: - gene_categories = ['gene'] - if file_format == 'gtf': - exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gtf_file(fn, chromosomes, **kwargs) + gene_categories = ["gene"] + if file_format == "gtf": + exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gtf_file( + fn, chromosomes, **kwargs + ) else: # gff/gff3 - exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gff_file(fn, chromosomes, **kwargs) + exons, transcripts, gene_infos, cds_start, cds_stop, skipped = _read_gff_file( + fn, chromosomes, **kwargs + ) if skipped: - logger.info('skipped the following categories: %s', skipped.keys()) - - logger.debug('construct interval trees for genes...') + logger.info("skipped the following categories: %s", skipped.keys()) + + logger.debug("construct interval trees for genes...") genes: dict[str, IntervalTree[Gene]] = {} for chrom in gene_infos: for info, _, _ in gene_infos[chrom].values(): try: - info['reference'] = info.pop('properties') + info["reference"] = info.pop("properties") except KeyError: logger.error(info) raise - genes[chrom] = IntervalTree(Gene(start, end, info, transcriptome) for info, start, end in gene_infos[chrom].values()) + genes[chrom] = IntervalTree( + Gene(start, end, info, transcriptome) + for info, start, end in gene_infos[chrom].values() + ) # sort the exons - logger.debug('sorting exon positions...') + logger.debug("sorting exon positions...") for tid in exons: exons[tid].sort() - all_genes = set().union(*(set(gene_info.keys()) for gene_info in gene_infos.values())) - missed_genes = [gene_id for gene_id in transcripts.keys() if gene_id not in all_genes] + all_genes = set().union( + *(set(gene_info.keys()) for gene_info in gene_infos.values()) + ) + missed_genes = [ + gene_id for gene_id in transcripts.keys() if gene_id not in all_genes + ] if missed_genes: # 
logger.debug('/n'.join(gene_id+str(transcript) for gene_id, transcript in missed_genes.items())) notfound = len(missed_genes) found = sum((len(t) for t in genes.values())) - logger.warning('Missing genes! Found gene information in categories %s for %s/%s genes', gene_categories, found, found + notfound) - logger.debug('building gene data structure...') + logger.warning( + "Missing genes! Found gene information in categories %s for %s/%s genes", + gene_categories, + found, + found + notfound, + ) + logger.debug("building gene data structure...") # add transcripts to genes for chrom in genes: for gene in genes[chrom]: gene_id = gene.id transcript = transcripts.get(gene_id, {gene_id: {}}) for transcript_id, transcript_info in transcript.items(): - transcript_info['transcript_id'] = transcript_id + transcript_info["transcript_id"] = transcript_id try: - transcript_info['exons'] = exons[transcript_id] + transcript_info["exons"] = exons[transcript_id] except KeyError: # genes without exons get a single exons transcript - transcript_info['exons'] = [tuple(gene[:2])] + transcript_info["exons"] = [tuple(gene[:2])] # add cds if transcript_id in cds_start and transcript_id in cds_stop: - transcript_info['CDS'] = (cds_start[transcript_id], cds_stop[transcript_id]) if cds_start[transcript_id] < cds_stop[transcript_id] else (cds_stop[transcript_id], cds_start[transcript_id]) - gene.data['reference'].setdefault('transcripts', []).append(transcript_info) + transcript_info["CDS"] = ( + (cds_start[transcript_id], cds_stop[transcript_id]) + if cds_start[transcript_id] < cds_stop[transcript_id] + else (cds_stop[transcript_id], cds_start[transcript_id]) + ) + gene.data["reference"].setdefault("transcripts", []).append( + transcript_info + ) if short_exon_th is not None: - short_exons = {exon for transcript in gene.data['reference']['transcripts'] for exon in transcript['exons'] if exon[1] - exon[0] <= short_exon_th} + short_exons = { + exon + for transcript in 
gene.data["reference"]["transcripts"] + for exon in transcript["exons"] + if exon[1] - exon[0] <= short_exon_th + } if short_exons: - gene.data['reference']['short_exons'] = short_exons + gene.data["reference"]["short_exons"] = short_exons return genes @@ -1202,22 +1927,35 @@ def import_sqanti_classification(self: Transcriptome, path: str, progress_bar=Tr See https://github.com/ConesaLab/SQANTI3/wiki/Understanding-the-output-of-SQANTI3-QC#classifcols for details. """ - sqanti_df = pd.read_csv(path, sep='\t') - for _, row in tqdm(sqanti_df.iterrows(), total=len(sqanti_df), disable=not progress_bar): - isoform = row['isoform'] - gene_id = '_'.join(isoform.split('_')[:-1]) + sqanti_df = pd.read_csv(path, sep="\t") + for _, row in tqdm( + sqanti_df.iterrows(), total=len(sqanti_df), disable=not progress_bar + ): + isoform = row["isoform"] + gene_id = "_".join(isoform.split("_")[:-1]) if gene_id not in self: - raise KeyError(f'Gene {gene_id} not found in transcriptome. Make sure you passed the correct SQANTI file') + raise KeyError( + f"Gene {gene_id} not found in transcriptome. Make sure you passed the correct SQANTI file" + ) gene = self[gene_id] - transcript_id = int(isoform.split('_')[-1]) + transcript_id = int(isoform.split("_")[-1]) gene.add_sqanti_classification(transcript_id, row) -def export_end_sequences(self: Transcriptome, reference: str, output: str, positive_query, negative_query, - start = True, window = (25, 25), **kwargs): - ''' +def export_end_sequences( + self: Transcriptome, + reference: str, + output: str, + positive_query, + negative_query, + start=True, + window=(25, 25), + **kwargs, +): + """ Generates two fasta files containing the reference sequences in a window around the TSS (or PAS) of all transcripts that meet and not meet the criterium respectively. + :param reference: Path to the reference genome in fasta format or a FastaFile handle :param output: Prefix for the two output files. 
Files will be generated as positive.fa and negative.fa :param positive_query: Filter string that is passed to iter_transcripts() to select the positive output @@ -1226,33 +1964,49 @@ def export_end_sequences(self: Transcriptome, reference: str, output: str, posit :param window: Tuple of bases specifying the window size around the TSS (PAS) as number of bases (upstream, downstream). Total window size is upstream + downstream + 1 :param kwargs: Additional arguments are passed to both calls of iter_transcripts() - ''' + """ with FastaFile(reference) as ref: known_positions = defaultdict(set) - with open(f'{output}_positive.fa', 'w') as positive: - for gene, transcript_id, transcript in self.iter_transcripts(query=positive_query, **kwargs): - center = transcript['exons'][0][0] if start == (transcript['strand'] == '+') else transcript['exons'][-1][1] - window_here = window if transcript['strand'] == '+' else window[::-1] + with open(f"{output}_positive.fa", "w") as positive: + for gene, transcript_id, transcript in self.iter_transcripts( + query=positive_query, **kwargs + ): + center = ( + transcript["exons"][0][0] + if start == (transcript["strand"] == "+") + else transcript["exons"][-1][1] + ) + window_here = window if transcript["strand"] == "+" else window[::-1] pos = (gene.chrom, center - window_here[0], center + window_here[1] + 1) if pos in known_positions[gene.chrom]: continue seq = ref.fetch(*pos) - positive.write(f'>{gene.id}\t{transcript_id}\t{pos[0]}:{pos[1]}-{pos[2]}\n{seq}\n') + positive.write( + f">{gene.id}\t{transcript_id}\t{pos[0]}:{pos[1]}-{pos[2]}\n{seq}\n" + ) known_positions[gene.chrom].add(pos) - with open(f'{output}_negative.fa', 'w') as negative: - for gene, transcript_id, transcript in self.iter_transcripts(query=negative_query, **kwargs): - center = transcript['exons'][0][0] if start == (transcript['strand'] == '+') else transcript['exons'][-1][1] - window_here = window if transcript['strand'] == '+' else window[::-1] + with 
open(f"{output}_negative.fa", "w") as negative: + for gene, transcript_id, transcript in self.iter_transcripts( + query=negative_query, **kwargs + ): + center = ( + transcript["exons"][0][0] + if start == (transcript["strand"] == "+") + else transcript["exons"][-1][1] + ) + window_here = window if transcript["strand"] == "+" else window[::-1] pos = (gene.chrom, center - window_here[0], center + window_here[1] + 1) if pos in known_positions[gene.chrom]: continue seq = ref.fetch(*pos) - negative.write(f'>{gene.id}\t{transcript_id}\t{pos[0]}:{pos[1]}-{pos[2]}\n{seq}\n') + negative.write( + f">{gene.id}\t{transcript_id}\t{pos[0]}:{pos[1]}-{pos[2]}\n{seq}\n" + ) known_positions[gene.chrom].add(pos) def collapse_immune_genes(self: Transcriptome, maxgap=300000): - ''' This function collapses annotation of immune genes (IG and TR) of a loci. + """This function collapses annotation of immune genes (IG and TR) of a loci. As immune genes are so variable, classical annotation as a set of transcripts is not meaningfull for those genes. In consequence, each component of an immune gene is usually stored as an individual gene. @@ -1262,40 +2016,65 @@ def collapse_immune_genes(self: Transcriptome, maxgap=300000): Immune genes are recognized by the gff/gtf property "gene_type" set to "IG*_gene" or "TR*_gene". Components within the distance of "maxgap" get collapsed to a single gene called TR/IG_locus_X. :param maxgap: Specify maximum distance between components of the same locus. 
- ''' - assert not self.samples, 'immune gene collapsing has to be called before adding long read samples' - num = {'IG': 0, 'TR': 0} + """ + assert ( + not self.samples + ), "immune gene collapsing has to be called before adding long read samples" + num = {"IG": 0, "TR": 0} for chrom in self.data: - for strand in ('+', '-'): - immune = {'IG': [], 'TR': []} + for strand in ("+", "-"): + immune = {"IG": [], "TR": []} for gene in self.data[chrom]: - if gene.strand != strand or not gene.is_annotated or 'gene_type' not in gene.data['reference']: + if ( + gene.strand != strand + or not gene.is_annotated + or "gene_type" not in gene.data["reference"] + ): continue - gene_type = gene.data['reference']['gene_type'] - if gene_type[:2] in immune and gene_type[-5:] == '_gene': + gene_type = gene.data["reference"]["gene_type"] + if gene_type[:2] in immune and gene_type[-5:] == "_gene": immune[gene_type[:2]].append(gene) for itype in immune: immune[itype].sort(key=lambda x: (x.start, x.end)) offset = 0 for i, gene in enumerate(immune[itype]): self.data[chrom].remove(gene) - if i + 1 == len(immune[itype]) or gene.end - immune[itype][i + 1].start > maxgap: - ref_info = {'gene_type': f'{itype}_gene', 'transcripts': [t for gene in immune[itype][offset:i + 1] for t in gene.ref_transcripts]} - info = {'ID': f'{itype}_locus_{num[itype]}', 'strand': strand, 'chr': chrom, 'reference': ref_info} + if ( + i + 1 == len(immune[itype]) + or gene.end - immune[itype][i + 1].start > maxgap + ): + ref_info = { + "gene_type": f"{itype}_gene", + "transcripts": [ + t + for gene in immune[itype][offset : i + 1] + for t in gene.ref_transcripts + ], + } + info = { + "ID": f"{itype}_locus_{num[itype]}", + "strand": strand, + "chr": chrom, + "reference": ref_info, + } start = immune[itype][offset].start end = immune[itype][i].end new_gene = Gene(start, end, info, self) self.data[chrom].add(new_gene) num[itype] += 1 offset = i + 1 - logger.info('collapsed %s immunoglobulin loci and %s T-cell receptor 
loci', num["IG"], num["TR"]) + logger.info( + "collapsed %s immunoglobulin loci and %s T-cell receptor loci", + num["IG"], + num["TR"], + ) # io utility functions -@ experimental -def get_mutations_from_bam(bam_file, genome_file, region, min_cov=.05): - '''not very efficient function to fetch mutations within a region from a bam file - not exported so far''' +@experimental +def get_mutations_from_bam(bam_file, genome_file, region, min_cov=0.05): + """not very efficient function to fetch mutations within a region from a bam file + not exported so far""" mutations = dict() exons = [] n = 0 @@ -1304,7 +2083,12 @@ def get_mutations_from_bam(bam_file, genome_file, region, min_cov=.05): n += 1 exons.append(junctions_from_cigar(read.cigartuples, read.reference_start)) - mutations = get_mutations(read.cigartuples, read.query_sequence, read.reference_start, read.query_qualities) + mutations = get_mutations( + read.cigartuples, + read.query_sequence, + read.reference_start, + read.query_qualities, + ) for pos, ref, alt, qual in mutations: mutations.setdefault(pos, {}).setdefault(alt, [0, ref, []]) mutations[pos][alt][0] += 1 @@ -1313,20 +2097,24 @@ def get_mutations_from_bam(bam_file, genome_file, region, min_cov=.05): if min_cov < 1: min_cov = n * min_cov - mutations = {pos: v for pos, v in mutations.items() if sum(v[alt][0] for alt in v) > min_cov} + mutations = { + pos: v for pos, v in mutations.items() if sum(v[alt][0] for alt in v) > min_cov + } with FastaFile(genome_file) as genome_fh: for pos, v in mutations.items(): for alt in v.values(): - alt[1] = '' if alt[1] < 0 else genome_fh.fetch(region[0], pos, pos + alt[1]) + alt[1] = ( + "" if alt[1] < 0 else genome_fh.fetch(region[0], pos, pos + alt[1]) + ) for transcript in exons: for exon in transcript: if pos >= exon[0] and pos <= exon[1]: - mutations[pos]['cov'] = mutations[pos].get('cov', 0) + 1 + mutations[pos]["cov"] = mutations[pos].get("cov", 0) + 1 return mutations def get_mutations(cigartuples, seq, ref_start, 
qual): - 'look up the bases affected by mutations as reported in the cigar string' + "look up the bases affected by mutations as reported in the cigar string" # cigar numbers: # 012345678 # MIDNSHP=X @@ -1336,7 +2124,7 @@ def get_mutations(cigartuples, seq, ref_start, qual): for cigar in cigartuples: if cigar[0] in (1, 2, 8): # I(ins), D(del) or X (missmatch): ref = -cigar[1] if cigar[0] == 1 else cigar[1] - alt_base = '' if cigar[0] == 2 else seq[seq_pos:(seq_pos + cigar[1])] + alt_base = "" if cigar[0] == 2 else seq[seq_pos : (seq_pos + cigar[1])] mutations.append((ref_pos, ref, alt_base, qual[seq_pos] if qual else None)) if cigar[0] in (0, 2, 3, 7, 8): # MDN=X -> move forward on reference ref_pos += cigar[1] @@ -1365,7 +2153,10 @@ def get_clipping(cigartuples, pos): return (pos, -cigartuples[0][1]) elif cigartuples[-1][0] == 4: # clipping at the end - get the reference position - return (pos + sum(c[1] for c in cigartuples[:-1] if c[0] in (0, 2, 3, 7, 8)), cigartuples[-1][1]) # MDN=X -> move forward on reference: + return ( + pos + sum(c[1] for c in cigartuples[:-1] if c[0] in (0, 2, 3, 7, 8)), + cigartuples[-1][1], + ) # MDN=X -> move forward on reference: else: return None @@ -1379,7 +2170,12 @@ def _set_alias(d, alias, required=True): except StopIteration: if not required: continue - logger.error('did not find alternative for %s- suggested terms are %s, but have only those keys: %s', pref, alt, list(d)) + logger.error( + "did not find alternative for %s- suggested terms are %s, but have only those keys: %s", + pref, + alt, + list(d), + ) raise for a in alt: d.pop(a, None) @@ -1387,20 +2183,49 @@ def _set_alias(d, alias, required=True): # human readable output def gene_table(self: Transcriptome, **filter_args): # ideas: extra_columns - '''Creates a gene summary table. + """Creates a gene summary table. Exports all genes within region to a table. - :param filter_args: Parameters (e.g. 
"region", "query") are passed to Transcriptome.iter_genes.''' + :param filter_args: Parameters (e.g. "region", "query") are passed to Transcriptome.iter_genes. + """ - colnames = ['chr', 'start', 'end', 'strand', 'gene_id', 'gene_name', 'n_transcripts'] - rows = [(gene.chrom, gene.start, gene.end, gene.strand, gene.id, gene.name, gene.n_transcripts) for gene in self.iter_genes(**filter_args)] + colnames = [ + "chr", + "start", + "end", + "strand", + "gene_id", + "gene_name", + "n_transcripts", + ] + rows = [ + ( + gene.chrom, + gene.start, + gene.end, + gene.strand, + gene.id, + gene.name, + gene.n_transcripts, + ) + for gene in self.iter_genes(**filter_args) + ] df = pd.DataFrame(rows, columns=colnames) return df -def transcript_table(self: Transcriptome, samples=None, groups=None, coverage=False, tpm=False, tpm_pseudocount=0, extra_columns=None, **filter_args): - '''Creates a transcript table. +def transcript_table( + self: Transcriptome, + samples=None, + groups=None, + coverage=False, + tpm=False, + tpm_pseudocount=0, + extra_columns=None, + **filter_args, +): + """Creates a transcript table. Exports all transcript isoforms within region to a table. @@ -1411,7 +2236,8 @@ def transcript_table(self: Transcriptome, samples=None, groups=None, coverage=Fa :param tpm_pseudocount: This value is added to the coverage for each transcript, before calculating tpm. :param extra_columns: Specify the additional information added to the table. These can be any transcript property as defined by the key in the transcript dict. - :param filter_args: Parameters (e.g. "region", "query", "min_coverage",...) are passed to Transcriptome.iter_transcripts.''' + :param filter_args: Parameters (e.g. "region", "query", "min_coverage",...) are passed to Transcriptome.iter_transcripts. 
+ """ if samples is None: if groups is None: @@ -1429,7 +2255,9 @@ def transcript_table(self: Transcriptome, samples=None, groups=None, coverage=Fa samples = self.samples samples_set = set(samples) samples_set.update(*groups.values()) - assert all(s in self.samples for s in samples_set), 'Not all specified samples are known' + assert all( + s in self.samples for s in samples_set + ), "Not all specified samples are known" if len(samples_set) == len(self.samples): all_samples = True sample_i = slice(None) @@ -1438,28 +2266,62 @@ def transcript_table(self: Transcriptome, samples=None, groups=None, coverage=Fa sample_i = [i for i, sample in enumerate(self.samples) if sample in samples_set] if not isinstance(extra_columns, list): - raise ValueError('extra_columns should be provided as list') - - colnames = ['chr', 'transcript_start', 'transcript_end', 'strand', 'gene_id', 'gene_name', 'transcript_nr', - 'transcript_length', 'num_exons', 'exon_starts', 'exon_ends', 'novelty_class', 'novelty_subclasses'] + raise ValueError("extra_columns should be provided as list") + + colnames = [ + "chr", + "transcript_start", + "transcript_end", + "strand", + "gene_id", + "gene_name", + "transcript_nr", + "transcript_length", + "num_exons", + "exon_starts", + "exon_ends", + "novelty_class", + "novelty_subclasses", + ] colnames += extra_columns rows = [] cov = [] - for gene, transcript_ids, transcripts in self.iter_transcripts(**filter_args, genewise=True): + for gene, transcript_ids, transcripts in self.iter_transcripts( + **filter_args, genewise=True + ): if sample_i: - idx = (slice(None), transcript_ids) if all_samples else np.ix_(sample_i, transcript_ids) + idx = ( + (slice(None), transcript_ids) + if all_samples + else np.ix_(sample_i, transcript_ids) + ) cov.append(gene.coverage[idx]) for transcript_id, transcript in zip(transcript_ids, transcripts): - exons = transcript['exons'] - trlen = sum(e[1]-e[0] for e in exons) - nov_class, subcat = transcript['annotation'] + exons = 
transcript["exons"] + trlen = sum(e[1] - e[0] for e in exons) + nov_class, subcat = transcript["annotation"] # subcat_string = ';'.join(k if v is None else '{}:{}'.format(k, v) for k, v in subcat.items()) - e_starts, e_ends = (','.join(str(exons[i][j]) for i in range(len(exons))) for j in range(2)) - row = [gene.chrom, exons[0][0], exons[-1][1], gene.strand, gene.id, gene.name, transcript_id, trlen, len(exons), e_starts, e_ends, - SPLICE_CATEGORY[nov_class], ','.join(subcat)] + e_starts, e_ends = ( + ",".join(str(exons[i][j]) for i in range(len(exons))) for j in range(2) + ) + row = [ + gene.chrom, + exons[0][0], + exons[-1][1], + gene.strand, + gene.id, + gene.name, + transcript_id, + trlen, + len(exons), + e_starts, + e_ends, + SPLICE_CATEGORY[nov_class], + ",".join(subcat), + ] for k in extra_columns: - val = transcript.get(k, 'NA') + val = transcript.get(k, "NA") row.append(str(val) if isinstance(val, Iterable) else val) rows.append(row) @@ -1467,35 +2329,60 @@ def transcript_table(self: Transcriptome, samples=None, groups=None, coverage=Fa df = pd.DataFrame(rows, columns=colnames) if cov: df_list = [df] - cov = pd.DataFrame(np.concatenate(cov, 1).T, columns=self.samples if all_samples else [sample for i, sample in enumerate(self.samples) if i in sample_i]) - stab = self.sample_table.set_index('name') + cov = pd.DataFrame( + np.concatenate(cov, 1).T, + columns=( + self.samples + if all_samples + else [sample for i, sample in enumerate(self.samples) if i in sample_i] + ), + ) + stab = self.sample_table.set_index("name") if samples: if coverage: - df_list.append(cov[samples].add_suffix('_coverage')) + df_list.append(cov[samples].add_suffix("_coverage")) if tpm: - total = stab.loc[samples, 'nonchimeric_reads']+tpm_pseudocount*cov.shape[0] - df_list.append(((cov[samples]+tpm_pseudocount)/total*1e6).add_suffix('_tpm')) + total = ( + stab.loc[samples, "nonchimeric_reads"] + + tpm_pseudocount * cov.shape[0] + ) + df_list.append( + ((cov[samples] + tpm_pseudocount) / 
total * 1e6).add_suffix("_tpm") + ) if groups: - cov_gr = pd.DataFrame({group_name: cov[sample].sum(1) for group_name, sample in groups.items()}) + cov_gr = pd.DataFrame( + { + group_name: cov[sample].sum(1) + for group_name, sample in groups.items() + } + ) if coverage: - df_list.append(cov_gr.add_suffix('_sum_coverage')) + df_list.append(cov_gr.add_suffix("_sum_coverage")) if tpm: - total = {group_name: stab.loc[sample, 'nonchimeric_reads'].sum()+tpm_pseudocount*cov.shape[0] for group_name, sample in groups.items()} - df_list.append(((cov_gr+tpm_pseudocount)/total*1e6).add_suffix('_sum_tpm')) + total = { + group_name: stab.loc[sample, "nonchimeric_reads"].sum() + + tpm_pseudocount * cov.shape[0] + for group_name, sample in groups.items() + } + df_list.append( + ((cov_gr + tpm_pseudocount) / total * 1e6).add_suffix("_sum_tpm") + ) df = pd.concat(df_list, axis=1) return df -@ experimental -def chimeric_table(self: Transcriptome, region=None, query=None): # , star_chimeric=None, illu_len=200): - '''Creates a chimeric table +@experimental +def chimeric_table( + self: Transcriptome, region=None, query=None +): # , star_chimeric=None, illu_len=200): + """Creates a chimeric table This table contains relevant infos about breakpoints and coverage for chimeric genes. :param region: Specify the region, either as (chr, start, end) tuple or as "chr:start-end" string. If omitted specify the complete genome. :param query: Specify transcript filter query. 
- ''' + """ # todo: correct handeling of three part fusion events not yet implemented # todo: ambiguous alignment handling not yet implemented @@ -1506,16 +2393,44 @@ def chimeric_table(self: Transcriptome, region=None, query=None): # , star_chim raise NotImplementedError chim_tab = list() for bp, chimeric in self.chimeric.items(): - cov = tuple(sum(c.get(sample, 0) for c, _ in chimeric) for sample in self.samples) - genes = [info[4] if info[4] is not None else 'intergenic' for info in chimeric[0][1]] + cov = tuple( + sum(c.get(sample, 0) for c, _ in chimeric) for sample in self.samples + ) + genes = [ + info[4] if info[4] is not None else "intergenic" for info in chimeric[0][1] + ] for i, bp_i in enumerate(bp): - chim_tab.append(('_'.join(genes),) + bp_i[: 3] + (genes[i],) + bp_i[3:] + (genes[i + 1],) + (sum(cov),) + cov) - chim_tab = pd.DataFrame(chim_tab, columns=['name', 'chr1', 'strand1', 'breakpoint1', 'gene1', 'chr2', 'strand2', - 'breakpoint2', 'gene2', 'total_cov'] + [s + '_cov' for s in self.infos['sample_table'].name]) + chim_tab.append( + ("_".join(genes),) + + bp_i[:3] + + (genes[i],) + + bp_i[3:] + + (genes[i + 1],) + + (sum(cov),) + + cov + ) + chim_tab = pd.DataFrame( + chim_tab, + columns=[ + "name", + "chr1", + "strand1", + "breakpoint1", + "gene1", + "chr2", + "strand2", + "breakpoint2", + "gene2", + "total_cov", + ] + + [s + "_cov" for s in self.infos["sample_table"].name], + ) return chim_tab # todo: integrate short read coverage from star files + + # breakpoints = {} # todo: this should be the isoseq breakpoints # offset = 10 + len(self.infos['sample_table']) # for sa_idx, sa in enumerate(star_chimeric): @@ -1536,13 +2451,13 @@ def chimeric_table(self: Transcriptome, region=None, query=None): # , star_chim def openfile(fn, gzip=False): if gzip: - return gziplib.open(fn, 'wt') + return gziplib.open(fn, "wt") else: - return open(fn, 'w', encoding="utf8") + return open(fn, "w", encoding="utf8") -def write_gtf(self: Transcriptome, fn, 
source='isotools', gzip=False, **filter_args): - ''' +def write_gtf(self: Transcriptome, fn, source="isotools", gzip=False, **filter_args): + """ Exports the transcripts in gtf format to a file. :param fn: The filename to write the gtf. @@ -1550,37 +2465,83 @@ def write_gtf(self: Transcriptome, fn, source='isotools', gzip=False, **filter_a :param region: Specify genomic region to export to gtf. If omitted, export whole genome. :param gzip: Compress the output as gzip. :param filter_args: Specify transcript filter query. - ''' + """ with openfile(fn, gzip) as f: - logger.info('writing %sgtf file to %s', "gzip compressed " if gzip else "", fn) - for gene, transcript_ids, _ in self.iter_transcripts(genewise=True, **filter_args): + logger.info("writing %sgtf file to %s", "gzip compressed " if gzip else "", fn) + for gene, transcript_ids, _ in self.iter_transcripts( + genewise=True, **filter_args + ): lines = gene._to_gtf(transcript_ids=transcript_ids, source=source) - f.write('\n'.join(('\t'.join(str(field) for field in line) for line in lines)) + '\n') - - -def write_fasta(self: Transcriptome, genome_fn, fn, gzip=False, reference=False, protein=False, **filter_args): - ''' + f.write( + "\n".join(("\t".join(str(field) for field in line) for line in lines)) + + "\n" + ) + + +def write_fasta( + self: Transcriptome, + genome_fn, + fn, + gzip=False, + reference=False, + protein=False, + coverage=None, + **filter_args, +): + """ Exports the transcript sequences in fasta format to a file. :param genome_fn: Path to the genome in fastA format. :param reference: Specify whether the sequence is fetched for reference transcripts (True), or long read transcripts (False, default). :param protein: Return protein sequences (ORF) instead of transcript sequences. + :param coverage: By default, the coverage is not added to the header of the fasta. If set, the allowed values are: 'all', or 'sample'. + 'all' - total coverage for all samples; 'sample' - coverage by sample. 
:param fn: The filename to write the fasta. :param gzip: Compress the output as gzip. :param filter_args: Additional filter arguments (e.g. "region", "gois", "query") are passed to iter_transcripts. - ''' - - with openfile(fn, gzip) as f: - logger.info('writing %sfasta file to %s', "gzip compressed " if gzip else "", fn) - for gene, transcript_ids, _ in self.iter_transcripts(genewise=True, **filter_args): - tr_seqs = gene.get_sequence(genome_fn, transcript_ids, reference=reference, protein=protein) - f.write('\n'.join(f'>{gene.id}_{k} gene={gene.name}\n{v}' for k,v in tr_seqs.items()) + '\n') + """ + if coverage: + assert coverage in [ + "all", + "sample", + ], 'if coverage is set, it must be "all", or "sample"' -def export_alternative_splicing(self: Transcriptome, out_dir, out_format='mats', reference=False, min_total=100, - min_alt_fraction=.1, samples=None, region=None, query=None, progress_bar=True): - '''Exports alternative splicing events defined by the transcriptome. + with openfile(fn, gzip) as f: + logger.info( + "writing %sfasta file to %s", "gzip compressed " if gzip else "", fn + ) + for gene, transcript_ids, _ in self.iter_transcripts( + genewise=True, **filter_args + ): + tr_seqs = gene.get_sequence( + genome_fn, transcript_ids, reference=reference, protein=protein + ) + if len(tr_seqs) > 0: + f.write( + "\n".join( + f">{gene.id}_{k} gene={gene.name}" + f'{(" coverage=" + (str(gene.coverage[:, k].sum()) if coverage == "all" else str(gene.coverage[:, k])) if coverage else "")}\n{v}' + for k, v in tr_seqs.items() + ) + + "\n" + ) + + +def export_alternative_splicing( + self: Transcriptome, + out_dir, + out_format="mats", + reference=False, + min_total=100, + min_alt_fraction=0.1, + samples=None, + region=None, + query=None, + progress_bar=True, +): + """Exports alternative splicing events defined by the transcriptome. This is intended to integrate splicing event analysis from short read data. 
Tools for short read data implement different formats for the import of events. @@ -1599,38 +2560,84 @@ def export_alternative_splicing(self: Transcriptome, out_dir, out_format='mats', In this case the following parameters are ignored :param samples: Specify the samples to consider :param min_total: Minimum total coverage over all selected samples. - :param min_alt_fraction: Minimum fraction of reads supporting the alternative.''' - if out_format == 'miso': - file_name = 'isotools_miso_{}.gff' + :param min_alt_fraction: Minimum fraction of reads supporting the alternative.""" + if out_format == "miso": + file_name = "isotools_miso_{}.gff" alt_splice_export = _miso_alt_splice_export - elif out_format == 'mats': - file_name = 'fromGTF.{}.txt' + elif out_format == "mats": + file_name = "fromGTF.{}.txt" alt_splice_export = _mats_alt_splice_export else: raise ValueError('out_format must be "miso" or "mats"') - types = {'ES': 'SE', '3AS': 'A3SS', '5AS': 'A5SS', 'IR': 'RI', 'ME': 'MXE'} # it looks like these are the "official" identifiers? - out_file = {st: out_dir + '/' + file_name.format(st) for st in types.values()} + types = { + "ES": "SE", + "3AS": "A3SS", + "5AS": "A5SS", + "IR": "RI", + "ME": "MXE", + } # it looks like these are the "official" identifiers? 
+ out_file = {st: out_dir + "/" + file_name.format(st) for st in types.values()} if samples is None: samples = self.samples - assert all(sample in self.samples for sample in samples), 'not all specified samples found' + assert all( + sample in self.samples for sample in samples + ), "not all specified samples found" sample_dict = {sample: i for i, sample in enumerate(self.samples)} sidx = np.array([sample_dict[sample] for sample in samples]) - assert 0 < min_alt_fraction < .5, 'min_alt_fraction must be > 0 and < 0.5' + assert 0 < min_alt_fraction < 0.5, "min_alt_fraction must be > 0 and < 0.5" count = {st: 0 for st in types.values()} with ExitStack() as stack: - fh = {st: stack.enter_context(open(out_file[st], 'w')) for st in out_file} - if out_format == 'mats': # prepare mats header - base_header = ['ID', 'GeneID', 'geneSymbol', 'chr', 'strand'] - add_header = {'SE': ['exonStart_0base', 'exonEnd', 'upstreamES', 'upstreamEE', 'downstreamES', 'downstreamEE'], - 'RI': ['riExonStart_0base', 'riExonEnd', 'upstreamES', 'upstreamEE', 'downstreamES', 'downstreamEE'], - 'MXE': ['1stExonStart_0base', '1stExonEnd', '2ndExonStart_0base', '2ndExonEnd', 'upstreamES', - 'upstreamEE', 'downstreamES', 'downstreamEE'], - 'A3SS': ['longExonStart_0base', 'longExonEnd', 'shortES', 'shortEE', 'flankingES', 'flankingEE'], - 'A5SS': ['longExonStart_0base', 'longExonEnd', 'shortES', 'shortEE', 'flankingES', 'flankingEE']} + fh = {st: stack.enter_context(open(out_file[st], "w")) for st in out_file} + if out_format == "mats": # prepare mats header + base_header = ["ID", "GeneID", "geneSymbol", "chr", "strand"] + add_header = { + "SE": [ + "exonStart_0base", + "exonEnd", + "upstreamES", + "upstreamEE", + "downstreamES", + "downstreamEE", + ], + "RI": [ + "riExonStart_0base", + "riExonEnd", + "upstreamES", + "upstreamEE", + "downstreamES", + "downstreamEE", + ], + "MXE": [ + "1stExonStart_0base", + "1stExonEnd", + "2ndExonStart_0base", + "2ndExonEnd", + "upstreamES", + "upstreamEE", + 
"downstreamES", + "downstreamEE", + ], + "A3SS": [ + "longExonStart_0base", + "longExonEnd", + "shortES", + "shortEE", + "flankingES", + "flankingEE", + ], + "A5SS": [ + "longExonStart_0base", + "longExonEnd", + "shortES", + "shortEE", + "flankingES", + "flankingEE", + ], + } for st in fh: - fh[st].write('\t'.join(base_header + add_header[st]) + '\n') + fh[st].write("\t".join(base_header + add_header[st]) + "\n") for gene in self.iter_genes(region, query, progress_bar=progress_bar): if reference and not gene.is_annotated: continue @@ -1638,48 +2645,144 @@ def export_alternative_splicing(self: Transcriptome, out_dir, out_format='mats', continue seg_graph = gene.ref_segment_graph if reference else gene.segment_graph - for setA, setB, nodeX, nodeY, splice_type in seg_graph.find_splice_bubbles(types=('ES', '3AS', '5AS', 'IR', 'ME')): + for setA, setB, nodeX, nodeY, splice_type in seg_graph.find_splice_bubbles( + types=("ES", "3AS", "5AS", "IR", "ME") + ): if not reference: junction_cov = gene.coverage[np.ix_(sidx, setA)].sum(1) total_cov = gene.coverage[np.ix_(sidx, setB)].sum(1) + junction_cov - if total_cov.sum() < min_total or (not min_alt_fraction < junction_cov.sum() / total_cov.sum() < 1 - min_alt_fraction): + if total_cov.sum() < min_total or ( + not min_alt_fraction + < junction_cov.sum() / total_cov.sum() + < 1 - min_alt_fraction + ): continue st = types[splice_type] - lines = alt_splice_export(setA, setB, nodeX, nodeY, st, reference, gene, count[st]) + lines = alt_splice_export( + setA, setB, nodeX, nodeY, st, reference, gene, count[st] + ) if lines: count[st] += len(lines) - fh[st].write('\n'.join(('\t'.join(str(field) for field in line) for line in lines)) + '\n') + fh[st].write( + "\n".join( + ("\t".join(str(field) for field in line) for line in lines) + ) + + "\n" + ) -def _miso_alt_splice_export(setA, setB, nodeX, nodeY, splice_type, reference, gene, offset): +def _miso_alt_splice_export( + setA, setB, nodeX, nodeY, splice_type, reference, gene, offset 
+): seg_graph = gene.ref_segment_graph if reference else gene.segment_graph - event_id = f'{gene.chrom}:{seg_graph[nodeX].end}-{seg_graph[nodeY].start}_st' + event_id = f"{gene.chrom}:{seg_graph[nodeX].end}-{seg_graph[nodeY].start}_st" # TODO: Mutually exclusives extend beyond nodeY - and have potentially multiple A "mRNAs" # TODO: is it possible to extend exons at nodeX and Y - if all/"most" transcript from setA and B agree? # if st=='ME': # nodeY=min(seg_graph._pas[setA]) lines = [] - lines.append([gene.chrom, splice_type, 'gene', seg_graph[nodeX].start, seg_graph[nodeY].end, '.', gene.strand, '.', f'ID={event_id};gene_name={gene.name};gene_id={gene.id}']) + lines.append( + [ + gene.chrom, + splice_type, + "gene", + seg_graph[nodeX].start, + seg_graph[nodeY].end, + ".", + gene.strand, + ".", + f"ID={event_id};gene_name={gene.name};gene_id={gene.id}", + ] + ) # lines.append((gene.chrom, st, 'mRNA', seg_graph[nodeX].start, seg_graph[nodeY].end, '.',gene.strand, '.', f'Parent={event_id};ID={event_id}.A')) # lines.append((gene.chrom, st, 'exon', seg_graph[nodeX].start, seg_graph[nodeX].end, '.',gene.strand, '.', f'Parent={event_id}.A;ID={event_id}.A.up')) # lines.append((gene.chrom, st, 'exon', seg_graph[nodeY].start, seg_graph[nodeY].end, '.',gene.strand, '.', f'Parent={event_id}.A;ID={event_id}.A.down')) - for i, exons in enumerate({tuple(seg_graph._get_all_exons(nodeX, nodeY, transcript)) for transcript in setA}): - lines.append((gene.chrom, splice_type, 'mRNA', exons[0][0], exons[-1][1], '.', gene.strand, '.', f'Parent={event_id};ID={event_id}.A{i}')) + for i, exons in enumerate( + { + tuple(seg_graph._get_all_exons(nodeX, nodeY, transcript)) + for transcript in setA + } + ): + lines.append( + ( + gene.chrom, + splice_type, + "mRNA", + exons[0][0], + exons[-1][1], + ".", + gene.strand, + ".", + f"Parent={event_id};ID={event_id}.A{i}", + ) + ) lines[0][3] = min(lines[0][3], lines[-1][3]) lines[0][4] = max(lines[0][4], lines[-1][4]) for j, exon in enumerate(exons): 
- lines.append((gene.chrom, splice_type, 'exon', exon[0], exon[1], '.', gene.strand, '.', f'Parent={event_id}.A{i};ID={event_id}.A{i}.{j}')) - for i, exons in enumerate({tuple(seg_graph._get_all_exons(nodeX, nodeY, transcript)) for transcript in setB}): - lines.append((gene.chrom, splice_type, 'mRNA', exons[0][0], exons[-1][1], '.', gene.strand, '.', f'Parent={event_id};ID={event_id}.B{i}')) + lines.append( + ( + gene.chrom, + splice_type, + "exon", + exon[0], + exon[1], + ".", + gene.strand, + ".", + f"Parent={event_id}.A{i};ID={event_id}.A{i}.{j}", + ) + ) + for i, exons in enumerate( + { + tuple(seg_graph._get_all_exons(nodeX, nodeY, transcript)) + for transcript in setB + } + ): + lines.append( + ( + gene.chrom, + splice_type, + "mRNA", + exons[0][0], + exons[-1][1], + ".", + gene.strand, + ".", + f"Parent={event_id};ID={event_id}.B{i}", + ) + ) lines[0][3] = min(lines[0][3], lines[-1][3]) lines[0][4] = max(lines[0][4], lines[-1][4]) for j, exon in enumerate(exons): - lines.append((gene.chrom, splice_type, 'exon', exon[0], exon[1], '.', gene.strand, '.', f'Parent={event_id}.B{i};ID={event_id}.B{i}.{j}')) + lines.append( + ( + gene.chrom, + splice_type, + "exon", + exon[0], + exon[1], + ".", + gene.strand, + ".", + f"Parent={event_id}.B{i};ID={event_id}.B{i}.{j}", + ) + ) return lines -def _mats_alt_splice_export(setA, setB, nodeX, nodeY, st, reference, gene, offset, use_top_isoform=True, use_top_alternative=True): +def _mats_alt_splice_export( + setA, + setB, + nodeX, + nodeY, + st, + reference, + gene, + offset, + use_top_isoform=True, + use_top_alternative=True, +): # use_top_isoform and use_top_alternative are ment to simplify the output, in order to not confuse rMATS with to many options # 'ID','GeneID','geneSymbol','chr','strand' # and ES/EE for the relevant exons @@ -1693,28 +2796,42 @@ def _mats_alt_splice_export(setA, setB, nodeX, nodeY, st, reference, gene, offse seg_graph = gene.ref_segment_graph if reference else gene.segment_graph lines = [] - if 
gene.chrom[: 3] != 'chr': - chrom = 'chr' + gene.chrom + if gene.chrom[:3] != "chr": + chrom = "chr" + gene.chrom else: chrom = gene.chrom if use_top_isoform: # use flanking exons from top isoform - all_transcript_ids = setA+setB + all_transcript_ids = setA + setB if not reference: # most covered isoform - top_isoform = all_transcript_ids[gene.coverage[:, all_transcript_ids].sum(0).argmax()] # top covered isoform across all samples + top_isoform = all_transcript_ids[ + gene.coverage[:, all_transcript_ids].sum(0).argmax() + ] # top covered isoform across all samples nodeX_start = seg_graph._get_exon_start(top_isoform, nodeX) nodeY_end = seg_graph._get_exon_end(top_isoform, nodeY) else: # for reference: most frequent - nodeX_start = Counter([seg_graph._get_exon_start(n, nodeX) for n in all_transcript_ids]).most_common(1)[0][0] - nodeY_end = Counter([seg_graph._get_exon_end(n, nodeY) for n in all_transcript_ids]).most_common(1)[0][0] - - exonsA = ((seg_graph[nodeX_start].start, seg_graph[nodeX].end), (seg_graph[nodeY].start, seg_graph[nodeY_end].end)) # flanking/outer "exons" of setA() + nodeX_start = Counter( + [seg_graph._get_exon_start(n, nodeX) for n in all_transcript_ids] + ).most_common(1)[0][0] + nodeY_end = Counter( + [seg_graph._get_exon_end(n, nodeY) for n in all_transcript_ids] + ).most_common(1)[0][0] + + exonsA = ( + (seg_graph[nodeX_start].start, seg_graph[nodeX].end), + (seg_graph[nodeY].start, seg_graph[nodeY_end].end), + ) # flanking/outer "exons" of setA() else: - exonsA = ((seg_graph[nodeX].start, seg_graph[nodeX].end), (seg_graph[nodeY].start, seg_graph[nodeY].end)) # flanking/outer "exons" of setA + exonsA = ( + (seg_graph[nodeX].start, seg_graph[nodeX].end), + (seg_graph[nodeY].start, seg_graph[nodeY].end), + ) # flanking/outer "exons" of setA # _get_all_exons does not extend the exons beyond the nodeX/Y - alternatives = [tuple(seg_graph._get_all_exons(nodeX, nodeY, b_tr)) for b_tr in setB] + alternatives = [ + 
tuple(seg_graph._get_all_exons(nodeX, nodeY, b_tr)) for b_tr in setB + ] if use_top_alternative: if reference: c = Counter(alternatives) @@ -1728,46 +2845,68 @@ def _mats_alt_splice_export(setA, setB, nodeX, nodeY, st, reference, gene, offse alternatives = set(alternatives) for exonsB in alternatives: exons_sel = [] - if st in ['A3SS', 'A5SS'] and len(exonsB) == 2: + if st in ["A3SS", "A5SS"] and len(exonsB) == 2: if exonsA[0][1] == exonsB[0][1]: # A5SS on - strand or A3SS on + strand - exons_sel.append([(exonsB[1][0], exonsA[1][1]), exonsA[1], exonsA[0]]) # long short flanking + exons_sel.append( + [(exonsB[1][0], exonsA[1][1]), exonsA[1], exonsA[0]] + ) # long short flanking else: # A5SS on + strand or A3SS on - strand - exons_sel.append([(exonsA[0][0], exonsB[0][1]), exonsA[0], exonsA[1]]) # long short flanking - elif st == 'SE' and len(exonsB) == 3: + exons_sel.append( + [(exonsA[0][0], exonsB[0][1]), exonsA[0], exonsA[1]] + ) # long short flanking + elif st == "SE" and len(exonsB) == 3: # just to be sure everything is consistent - assert exonsA[0][1] == exonsB[0][1] and exonsA[1][0] == exonsB[2][0], f'invalid exon skipping {exonsA} vs {exonsB}' + assert ( + exonsA[0][1] == exonsB[0][1] and exonsA[1][0] == exonsB[2][0] + ), f"invalid exon skipping {exonsA} vs {exonsB}" # e_order = (1, 0, 2) if gene.strand == '+' else (1, 2, 0) - exons_sel.append([exonsB[i] for i in (1, 0, 2)]) # skipped, upstream, downstream - elif st == 'RI' and len(exonsB) == 1: - exons_sel.append([(exonsA[0][0], exonsA[1][1]), exonsA[0], exonsA[1]]) # retained, upstream, downstream + exons_sel.append( + [exonsB[i] for i in (1, 0, 2)] + ) # skipped, upstream, downstream + elif st == "RI" and len(exonsB) == 1: + exons_sel.append( + [(exonsA[0][0], exonsA[1][1]), exonsA[0], exonsA[1]] + ) # retained, upstream, downstream # if gene.strand == '+' else [exonsB[0], exonsA[1], exonsA[0]]) - elif st == 'MXE' and len(exonsB) == 3: + elif st == "MXE" and len(exonsB) == 3: # nodeZ=next(idx for idx,n in 
enumerate(seg_graph) if n.start==exonsB[-1][0]) # multiple exonA possibilities, so we need to check all of them - for exonsA_ME in {tuple(seg_graph._get_all_exons(nodeX, nodeY, a_tr)) for a_tr in setA}: + for exonsA_ME in { + tuple(seg_graph._get_all_exons(nodeX, nodeY, a_tr)) for a_tr in setA + }: if len(exonsA_ME) != 3: # complex events are not possible in rMATS continue - assert exonsA_ME[0] == exonsB[0] and exonsA_ME[2] == exonsB[2] # should always be true + assert ( + exonsA_ME[0] == exonsB[0] and exonsA_ME[2] == exonsB[2] + ) # should always be true # assert exonsA_ME[0][1] == exonsA[0][1] and exonsA_ME[2][0] == exonsA[1][0] # should always be true # '1st','2nd', 'upstream', 'downstream' exons_sel.append([exonsB[1], exonsA_ME[1], exonsA[0], exonsA[1]]) - for exons in exons_sel: # usually there is only one rMATS event per exonB, but for MXE we may get several + for ( + exons + ) in ( + exons_sel + ): # usually there is only one rMATS event per exonB, but for MXE we may get several # lines.append([f'"{gene.id}"', f'"{gene.name}"', chrom, gene.strand] + [pos for exon in exons for pos in ((exon[1],exon[0]) if gene.strand == '-' else exon)]) - lines.append([f'"{gene.id}"', f'"{gene.name}"', chrom, gene.strand] + [pos for exon in exons for pos in exon]) # no need to reverse the order of exon start/end + lines.append( + [f'"{gene.id}"', f'"{gene.name}"', chrom, gene.strand] + + [pos for exon in exons for pos in exon] + ) # no need to reverse the order of exon start/end return [[offset + count] + l for count, l in enumerate(lines)] def get_gff_chrom_dict(gff: TabixFile, chromosomes): - 'fetch chromosome ids - in case they use ids in gff for the chromosomes' + "fetch chromosome ids - in case they use ids in gff for the chromosomes" chrom = {} for c in gff.contigs: # loggin.debug ("---"+c) - for line in gff.fetch(c, 1, 2): # chromosomes span the entire chromosome, so they can be fetched like that + for line in gff.fetch( + c, 1, 2 + ): # chromosomes span the entire 
chromosome, so they can be fetched like that if line[1] == "C": ls = line.split(sep="\t") if ls[2] == "region": - info = dict([pair.split("=") - for pair in ls[8].split(";")]) + info = dict([pair.split("=") for pair in ls[8].split(";")]) if "chromosome" in info.keys(): if chromosomes is None or info["chromosome"] in chromosomes: chrom[ls[0]] = info["chromosome"] @@ -1781,31 +2920,53 @@ def get_gff_chrom_dict(gff: TabixFile, chromosomes): class IntervalArray: - '''drop in replacement for the interval tree during construction, with faster lookup''' + """drop in replacement for the interval tree during construction, with faster lookup""" def __init__(self, total_size, bin_size=1e4): self.obj: dict[str, Interval] = {} - self.data: list[set[int]] = [set() for _ in range(int((total_size) // bin_size) + 1)] + self.data: list[set[int]] = [ + set() for _ in range(int((total_size) // bin_size) + 1) + ] self.bin_size = bin_size def overlap(self, begin, end): try: - candidates = {obj_id - for idx in range(int(begin // self.bin_size), int(end // self.bin_size) + 1) - for obj_id in self.data[idx]} + candidates = { + obj_id + for idx in range( + int(begin // self.bin_size), int(end // self.bin_size) + 1 + ) + for obj_id in self.data[idx] + } except IndexError: - logger.error('requesting interval between %s and %s, but array is allocated only until position %s', begin, end, len(self.data)*self.bin_size) + logger.error( + "requesting interval between %s and %s, but array is allocated only until position %s", + begin, + end, + len(self.data) * self.bin_size, + ) raise # this assumes object has range obj[0] to obj[1] - return (self.obj[obj_id] for obj_id in candidates if has_overlap((begin, end), self.obj[obj_id])) + return ( + self.obj[obj_id] + for obj_id in candidates + if has_overlap((begin, end), self.obj[obj_id]) + ) def add(self, obj: Interval): self.obj[id(obj)] = obj try: - for idx in range(int(obj.begin // self.bin_size), int(obj.end // self.bin_size) + 1): + for idx in 
range( + int(obj.begin // self.bin_size), int(obj.end // self.bin_size) + 1 + ): self.data[idx].add(id(obj)) except IndexError: - logger.error('adding interval from %s to %s, but array is allocated only until position %s', obj.begin, obj.end, len(self.data)*self.bin_size) + logger.error( + "adding interval from %s to %s, but array is allocated only until position %s", + obj.begin, + obj.end, + len(self.data) * self.bin_size, + ) raise def __len__(self): diff --git a/src/isotools/_transcriptome_stats.py b/src/isotools/_transcriptome_stats.py index a18a6b9..841c5a5 100644 --- a/src/isotools/_transcriptome_stats.py +++ b/src/isotools/_transcriptome_stats.py @@ -1,4 +1,10 @@ -from scipy.stats import binom, norm, chi2, betabinom, nbinom # pylint: disable-msg=E0611 +from scipy.stats import ( + binom, + norm, + chi2, + betabinom, + nbinom, +) # pylint: disable-msg=E0611 from scipy.special import gammaln, polygamma # pylint: disable-msg=E0611 from scipy.optimize import minimize, minimize_scalar import statsmodels.stats.multitest as multi @@ -16,7 +22,7 @@ from .splice_graph import SegmentGraph from ._utils import _filter_function, ASEType, str_var_triplet -logger = logging.getLogger('isotools') +logger = logging.getLogger("isotools") # differential splicing @@ -49,10 +55,21 @@ def binom_lr_test(x, n): def loglike_betabinom(params, k, n): - '''returns log likelihood of betabinomial and its partial derivatives''' + """returns log likelihood of betabinomial and its partial derivatives""" a, b = params - logpdf = gammaln(n + 1) + gammaln(k + a) + gammaln(n - k + b) + gammaln(a + b) - \ - (gammaln(k + 1) + gammaln(n - k + 1) + gammaln(a) + gammaln(b) + gammaln(n + a + b)) + logpdf = ( + gammaln(n + 1) + + gammaln(k + a) + + gammaln(n - k + b) + + gammaln(a + b) + - ( + gammaln(k + 1) + + gammaln(n - k + 1) + + gammaln(a) + + gammaln(b) + + gammaln(n + a + b) + ) + ) e = polygamma(0, a + b) - polygamma(0, n + a + b) da = e + polygamma(0, k + a) - polygamma(0, a) db = e + 
polygamma(0, n - k + b) - polygamma(0, b) @@ -60,11 +77,11 @@ def loglike_betabinom(params, k, n): def betabinom_ml(xi, ni): - '''Calculate maximum likelihood parameter of beta binomial distribution for a group of samples with xi successes and ni trials. + """Calculate maximum likelihood parameter of beta binomial distribution for a group of samples with xi successes and ni trials. :param xi: number of successes, here coverage of the alternative for all samples of the group as 1d numpy array :param ni: number of trials, here total coverage for the two sample groups for all samples of the group as 1d numpy array - ''' + """ # x and n must be np arrays if sum(ni) == 0: params = params_alt = None, None @@ -75,39 +92,62 @@ def betabinom_ml(xi, ni): d = prob.var() success = True if d == 0: # just one sample? or all exactly the same proportion - params = params_alt = m, None # in this case the betabinomial reduces to the binomial + params = params_alt = ( + m, + None, + ) # in this case the betabinomial reduces to the binomial else: d = max(d, 1e-6) # to avoid division by 0 - e = (m**2 - m + d) # helper + e = m**2 - m + d # helper # find ml estimates for a and b - mle = minimize(loglike_betabinom, x0=[-m * e / d, ((m - 1) * e) / d], bounds=((1e-6, None), (1e-6, None)), - args=(xi, ni), options={'maxiter': 250}, method='L-BFGS-B', jac=True) + mle = minimize( + loglike_betabinom, + x0=[-m * e / d, ((m - 1) * e) / d], + bounds=((1e-6, None), (1e-6, None)), + args=(xi, ni), + options={"maxiter": 250}, + method="L-BFGS-B", + jac=True, + ) a, b = params = mle.x - params_alt = (a / (a + b), a * b / ((a + b)**2 * (a + b + 1))) # get alternative parametrization (mu and disp) + params_alt = ( + a / (a + b), + a * b / ((a + b) ** 2 * (a + b + 1)), + ) # get alternative parametrization (mu and disp) # mle = minimize(loglike_betabinom2, x0=[-d/(m*e),d/((m-1)*e)],bounds=((1e-9,None),(1e-9,None)), # args=(xi,ni),options={'maxiter': 250}, method='L-BFGS-B', tol=1e-6) # params=([1/p for 
p in mle.x]) - params_alt = (a / (a + b), a * b / ((a + b)**2 * (a + b + 1))) # get alternative parametrization (mu and disp) + params_alt = ( + a / (a + b), + a * b / ((a + b) ** 2 * (a + b + 1)), + ) # get alternative parametrization (mu and disp) if not mle.success: # should not happen to often, mainly with mu close to boundaries - logger.debug(f'no convergence in betabinomial fit: k={xi}\nn={ni}\nparams={params}\nmessage={mle.message}') - success = False # prevent calculation of p-values based on non optimal parameters + logger.debug( + f"no convergence in betabinomial fit: k={xi}\nn={ni}\nparams={params}\nmessage={mle.message}" + ) + success = ( + False # prevent calculation of p-values based on non optimal parameters + ) return params, params_alt, success def betabinom_lr_test(x, n): - ''' Likelihood ratio test with random-effects betabinomial model. + """Likelihood ratio test with random-effects betabinomial model. This test modles x as betabinomial(n,a,b), eg a binomial distribution, where p follows beta ditribution with parameters a,b>0 mean m=a/(a+b) overdispersion d=ab/((a+b+1)(a+b)^2) --> a=-m(m^2-m+d)/d b=(m-1)(m^2-m+d)/d principle: log likelihood ratio of M0/M1 is chi2 distributed :param x: coverage of the alternative for the two sample groups - :param n: total coverage for the two sample groups''' + :param n: total coverage for the two sample groups""" if any(ni.sum() == 0 for ni in n): - return (np.nan, [None, None]) # one group is not covered at all - no test possible. Checking this to avoid RuntimeWarnings (Mean of empty slice) + return ( + np.nan, + [None, None], + ) # one group is not covered at all - no test possible. 
Checking this to avoid RuntimeWarnings (Mean of empty slice) x_all, n_all = (np.concatenate(x), np.concatenate(n)) # calculate ml parameters ml_1 = betabinom_ml(x[0], n[0]) @@ -118,11 +158,18 @@ def betabinom_lr_test(x, n): return np.nan, list(ml_1[1] + ml_2[1] + ml_all[1]) try: l0 = betabinom_ll(x_all, n_all, *ml_all[0]).sum() - l1 = betabinom_ll(x[0], n[0], *ml_1[0]).sum() + betabinom_ll(x[1], n[1], *ml_2[0]).sum() + l1 = ( + betabinom_ll(x[0], n[0], *ml_1[0]).sum() + + betabinom_ll(x[1], n[1], *ml_2[0]).sum() + ) except (ValueError, TypeError): - logger.critical(f'betabinom error: x={x}\nn={n}\nparams={ml_1[0]}/{ml_2[0]}/{ml_all[0]}') # should not happen + logger.critical( + f"betabinom error: x={x}\nn={n}\nparams={ml_1[0]}/{ml_2[0]}/{ml_all[0]}" + ) # should not happen raise - return chi2.sf(2 * (l1 - l0), 2), list(ml_1[1] + ml_2[1] + ml_all[1]) # note that we need two degrees of freedom here as h0 hsa two parameters, h1 has 4 + return chi2.sf(2 * (l1 - l0), 2), list( + ml_1[1] + ml_2[1] + ml_all[1] + ) # note that we need two degrees of freedom here as h0 hsa two parameters, h1 has 4 def betabinom_ll(x, n, a, b): @@ -132,36 +179,67 @@ def betabinom_ll(x, n, a, b): return betabinom.logpmf(x, n, a, b).sum() -TESTS = {'betabinom_lr': betabinom_lr_test, - 'binom_lr': binom_lr_test, - 'proportions': proportion_test} +TESTS = { + "betabinom_lr": betabinom_lr_test, + "binom_lr": binom_lr_test, + "proportions": proportion_test, +} -def _check_groups(transcriptome: 'Transcriptome', groups, n_groups=2): - assert len(groups) == n_groups, f"length of groups should be {n_groups}, but found {len(groups)}" +def _check_groups(transcriptome: "Transcriptome", groups, n_groups=2): + assert ( + len(groups) == n_groups + ), f"length of groups should be {n_groups}, but found {len(groups)}" # find groups and sample indices if isinstance(groups, dict): groupnames = list(groups) groups = list(groups.values()) - elif all(isinstance(groupname, str) and groupname in 
transcriptome.groups() for groupname in groups): + elif all( + isinstance(groupname, str) and groupname in transcriptome.groups() + for groupname in groups + ): groupnames = list(groups) groups = [transcriptome.groups()[gn] for gn in groupnames] elif all(isinstance(group, list) for group in groups): - groupnames = [f'group{i+1}' for i in range(len(groups))] + groupnames = [f"group{i+1}" for i in range(len(groups))] else: - raise ValueError('groups not found in dataset (samples must be a str, list or dict)') - notfound = [sample for group in groups for sample in group if sample not in transcriptome.samples] + raise ValueError( + "groups not found in dataset (samples must be a str, list or dict)" + ) + notfound = [ + sample + for group in groups + for sample in group + if sample not in transcriptome.samples + ] if notfound: raise ValueError(f"Cannot find the following samples: {notfound}") - assert all((groupname1 not in groupname2 for groupname1, groupname2 in itertools.permutations(groupnames, 2))), 'group names must not be contained in other group names' - sample_idx = {sample: idx for sample, idx in transcriptome._get_sample_idx().items()} + assert all( + ( + groupname1 not in groupname2 + for groupname1, groupname2 in itertools.permutations(groupnames, 2) + ) + ), "group names must not be contained in other group names" + sample_idx = { + sample: idx for sample, idx in transcriptome._get_sample_idx().items() + } grp_idx = [[sample_idx[sample] for sample in group] for group in groups] return groupnames, groups, grp_idx -def altsplice_test(self: 'Transcriptome', groups, min_total=100, min_alt_fraction=.1, min_n=10, min_sa=.51, test='auto', padj_method='fdr_bh', - types: Optional[list[ASEType]] = None, **kwargs): - '''Performs the alternative splicing event test. 
+def altsplice_test( + self: "Transcriptome", + groups, + min_total=100, + min_alt_fraction=0.1, + min_n=10, + min_sa=0.51, + test="auto", + padj_method="fdr_bh", + types: Optional[list[ASEType]] = None, + **kwargs, +): + """Performs the alternative splicing event test. :param groups: Dict with group names as keys and lists of sample names as values, defining the two groups for the test. If more then two groups are provided, test is performed between first two groups, but maximum likelihood parameters @@ -173,24 +251,32 @@ def altsplice_test(self: 'Transcriptome', groups, min_total=100, min_alt_fractio :param test: The name of one of the implemented statistical tests ('betabinom_lr','binom_lr','proportions'). :param padj_method: Specify the method for multiple testing correction. :param types: Restrict the analysis on types of events. If omitted, all types are tested. - :param kwargs: Additional keyword arguments are passed to iter_genes.''' + :param kwargs: Additional keyword arguments are passed to iter_genes.""" - noORF = (None, None, {'NMD': True}) + noORF = (None, None, {"NMD": True}) groupnames, groups, group_idx = _check_groups(self, groups) sidx = np.array(group_idx[0] + group_idx[1]) if isinstance(test, str): - if test == 'auto': - test = 'betabinom_lr' if min(len(group) for group in groups[:2]) > 1 else 'proportions' + if test == "auto": + test = ( + "betabinom_lr" + if min(len(group) for group in groups[:2]) > 1 + else "proportions" + ) test_name = test try: test = TESTS[test] except KeyError as e: - raise ValueError(f'test must be one of {str(list(TESTS))}') from e + raise ValueError(f"test must be one of {str(list(TESTS))}") from e else: - test_name = 'custom' + test_name = "custom" - logger.info('testing differential splicing for %s using %s test', ' vs '.join(f'{groupnames[i]} ({len(groups[i])})' for i in range(2)), test_name) + logger.info( + "testing differential splicing for %s using %s test", + " vs ".join(f"{groupnames[i]} ({len(groups[i])})" 
for i in range(2)), + test_name, + ) if min_sa < 1: min_sa *= sum(len(group) for group in groups[:2]) @@ -225,8 +311,10 @@ def altsplice_test(self: 'Transcriptome', groups, min_total=100, min_alt_fractio if sum((ni >= min_n).sum() for ni in n[:2]) < min_sa: continue pval, params = test(x[:2], n[:2]) - params_other = tuple(v for xi, ni in zip(x[2:], n[2:]) for v in betabinom_ml(xi, ni)[1]) - if splice_type in ['TSS', 'PAS']: + params_other = tuple( + v for xi, ni in zip(x[2:], n[2:]) for v in betabinom_ml(xi, ni)[1] + ) + if splice_type in ["TSS", "PAS"]: start, end = sg[nX].start, sg[nY].end if (splice_type == "TSS") == (gene.strand == "+"): novel = end not in known.get(splice_type, set()) @@ -236,32 +324,98 @@ def altsplice_test(self: 'Transcriptome', groups, min_total=100, min_alt_fractio start, end = sg[nX].end, sg[nY].start novel = (start, end) not in known.get(splice_type, set()) - nmdA = sum(gene.coverage[np.ix_(sidx, [transcript_id])].sum(None) - for transcript_id in setA if gene.transcripts[transcript_id].get('ORF', noORF)[2]['NMD'])/gene.coverage[np.ix_(sidx, setA)].sum(None) - nmdB = sum(gene.coverage[np.ix_(sidx, [transcript_id])].sum(None) - for transcript_id in setB if gene.transcripts[transcript_id].get('ORF', noORF)[2]['NMD'])/gene.coverage[np.ix_(sidx, setB)].sum(None) - res.append(tuple(itertools.chain((gene.name, gene.id, gene.chrom, gene.strand, start, end, splice_type, novel, pval, - sorted(setA, key=lambda x: -gene.coverage[np.ix_(sidx, [x])].sum(0)), - sorted(setB, key=lambda x: -gene.coverage[np.ix_(sidx, [x])].sum(0)), nmdA, nmdB), - params, params_other, - (val for lists in zip(x, n) for pair in zip(*lists) for val in pair)))) - colnames = ['gene', 'gene_id', 'chrom', 'strand', 'start', 'end', 'splice_type', 'novel', 'pvalue', 'trA', 'trB', 'nmdA', 'nmdB'] - colnames += [groupname + part for groupname in groupnames[:2] + ['total'] + groupnames[2:] for part in ['_PSI', '_disp']] - colnames += [f'{sample}_{groupname}_{w}' for groupname, group 
in zip(groupnames, groups) for sample in group for w in ['in_cov', 'total_cov']] + nmdA = sum( + gene.coverage[np.ix_(sidx, [transcript_id])].sum(None) + for transcript_id in setA + if gene.transcripts[transcript_id].get("ORF", noORF)[2]["NMD"] + ) / gene.coverage[np.ix_(sidx, setA)].sum(None) + nmdB = sum( + gene.coverage[np.ix_(sidx, [transcript_id])].sum(None) + for transcript_id in setB + if gene.transcripts[transcript_id].get("ORF", noORF)[2]["NMD"] + ) / gene.coverage[np.ix_(sidx, setB)].sum(None) + res.append( + tuple( + itertools.chain( + ( + gene.name, + gene.id, + gene.chrom, + gene.strand, + start, + end, + splice_type, + novel, + pval, + sorted( + setA, + key=lambda x: -gene.coverage[np.ix_(sidx, [x])].sum(0), + ), + sorted( + setB, + key=lambda x: -gene.coverage[np.ix_(sidx, [x])].sum(0), + ), + nmdA, + nmdB, + ), + params, + params_other, + ( + val + for lists in zip(x, n) + for pair in zip(*lists) + for val in pair + ), + ) + ) + ) + colnames = [ + "gene", + "gene_id", + "chrom", + "strand", + "start", + "end", + "splice_type", + "novel", + "pvalue", + "trA", + "trB", + "nmdA", + "nmdB", + ] + colnames += [ + groupname + part + for groupname in groupnames[:2] + ["total"] + groupnames[2:] + for part in ["_PSI", "_disp"] + ] + colnames += [ + f"{sample}_{groupname}_{w}" + for groupname, group in zip(groupnames, groups) + for sample in group + for w in ["in_cov", "total_cov"] + ] df = pd.DataFrame(res, columns=colnames) try: - mask = np.isfinite(df['pvalue']) + mask = np.isfinite(df["pvalue"]) padj = np.empty(mask.shape) padj.fill(np.nan) - padj[mask] = multi.multipletests(df.loc[mask, 'pvalue'], method=padj_method)[1] - df.insert(8, 'padj', padj) + padj[mask] = multi.multipletests(df.loc[mask, "pvalue"], method=padj_method)[1] + df.insert(8, "padj", padj) except TypeError as e: # apparently this happens if df is empty... 
- logger.error(f'unexpected error during calculation of adjusted p-values: {e}') + logger.error(f"unexpected error during calculation of adjusted p-values: {e}") return df -def die_test(self: 'Transcriptome', groups, min_cov=25, n_isoforms=10, padj_method='fdr_bh', **kwargs): - ''' Reimplementation of the DIE test, suggested by Joglekar et al in Nat Commun 12, 463 (2021): +def die_test( + self: "Transcriptome", + groups, + min_cov=25, + n_isoforms=10, + padj_method="fdr_bh", + **kwargs, +): + """Reimplementation of the DIE test, suggested by Joglekar et al in Nat Commun 12, 463 (2021): "A spatially resolved brain region- and cell type-specific isoform atlas of the postnatal mouse brain" Syntax and parameters follow the original implementation in https://github.com/noush-joglekar/scisorseqr/blob/master/inst/RScript/IsoformTest.R @@ -269,24 +423,45 @@ def die_test(self: 'Transcriptome', groups, min_cov=25, n_isoforms=10, padj_meth :param groups: Dict with group names as keys and lists of sample names as values, defining the two groups for the test. :param min_cov: Minimal number of reads per group for each gene. :param n_isoforms: Number of isoforms to consider in the test for each gene. All additional least expressed isoforms get summarized. 
- :param kwargs: Additional keyword arguments are passed to iter_genes.''' + :param kwargs: Additional keyword arguments are passed to iter_genes.""" groupnames, groups, grp_idx = _check_groups(self, groups) - logger.info('testing differential isoform expression (DIE) for %s.', ' vs '.join(f'{groupnames[i]} ({len(groups[i])})' for i in range(2))) - - result = [(gene.id, gene.name, gene.chrom, gene.strand, gene.start, gene.end) + - gene.die_test(grp_idx, min_cov, n_isoforms) for gene in self.iter_genes(**kwargs)] - result = pd.DataFrame(result, columns=['gene_id', 'gene_name', 'chrom', 'strand', 'start', 'end', 'pvalue', 'deltaPI', 'transcript_ids']) - mask = np.isfinite(result['pvalue']) + logger.info( + "testing differential isoform expression (DIE) for %s.", + " vs ".join(f"{groupnames[i]} ({len(groups[i])})" for i in range(2)), + ) + + result = [ + (gene.id, gene.name, gene.chrom, gene.strand, gene.start, gene.end) + + gene.die_test(grp_idx, min_cov, n_isoforms) + for gene in self.iter_genes(**kwargs) + ] + result = pd.DataFrame( + result, + columns=[ + "gene_id", + "gene_name", + "chrom", + "strand", + "start", + "end", + "pvalue", + "deltaPI", + "transcript_ids", + ], + ) + mask = np.isfinite(result["pvalue"]) padj = np.empty(mask.shape) padj.fill(np.nan) - padj[mask] = multi.multipletests(result.loc[mask, 'pvalue'], method=padj_method)[1] - result.insert(6, 'padj', padj) + padj[mask] = multi.multipletests(result.loc[mask, "pvalue"], method=padj_method)[1] + result.insert(6, "padj", padj) return result -def alternative_splicing_events(self, min_total=100, min_alt_fraction=.1, samples=None, **kwargs): - '''Finds alternative splicing events. +def alternative_splicing_events( + self, min_total=100, min_alt_fraction=0.1, samples=None, **kwargs +): + """Finds alternative splicing events. Finds alternative splicing events and potential transcription start sites/polyA sites by searching for splice bubbles in the Segment Graph. 
@@ -296,15 +471,15 @@ def alternative_splicing_events(self, min_total=100, min_alt_fraction=.1, sample :param min_alt_fraction: Minimum fraction of reads supporting the alternative. :param samples: Specify the samples to consider. If omitted, all samples are selected. :param kwargs: Additional keyword arguments are passed to iter_genes. - :return: Table with alternative splicing events.''' + :return: Table with alternative splicing events.""" bubbles = [] if samples is None: samples = self.samples - assert all(s in self.samples for s in samples), 'not all specified samples found' + assert all(s in self.samples for s in samples), "not all specified samples found" sample_dict = {sample: i for i, sample in enumerate(self.samples)} sidx = np.array([sample_dict[sample] for sample in samples]) - assert 0 < min_alt_fraction < .5, 'min_alt_fraction must be > 0 and < 0.5' + assert 0 < min_alt_fraction < 0.5, "min_alt_fraction must be > 0 and < 0.5" for gene in self.iter_genes(**kwargs): if gene.coverage[sidx, :].sum() < min_total: continue @@ -315,17 +490,28 @@ def alternative_splicing_events(self, min_total=100, min_alt_fraction=.1, sample for _, _, nX, nY, splice_type in ref_seg_graph.find_splice_bubbles(): if splice_type in ("TSS", "PAS"): if (splice_type == "TSS") == (gene.strand == "+"): - known.setdefault(splice_type, set()).add((ref_seg_graph[nX].end)) + known.setdefault(splice_type, set()).add( + (ref_seg_graph[nX].end) + ) else: - known.setdefault(splice_type, set()).add((ref_seg_graph[nY].start)) + known.setdefault(splice_type, set()).add( + (ref_seg_graph[nY].start) + ) else: - known.setdefault(splice_type, set()).add((ref_seg_graph[nX].end, ref_seg_graph[nY].start)) + known.setdefault(splice_type, set()).add( + (ref_seg_graph[nX].end, ref_seg_graph[nY].start) + ) seg_graph: SegmentGraph = gene.segment_graph for setA, setB, nX, nY, splice_type in seg_graph.find_splice_bubbles(): junction_cov = gene.coverage[np.ix_(sidx, setA)].sum(1) total_cov = 
gene.coverage[np.ix_(sidx, setB)].sum(1) + junction_cov - if total_cov.sum() >= min_total and min_alt_fraction < junction_cov.sum() / total_cov.sum() < 1 - min_alt_fraction: - if splice_type in ['TSS', 'PAS']: + if ( + total_cov.sum() >= min_total + and min_alt_fraction + < junction_cov.sum() / total_cov.sum() + < 1 - min_alt_fraction + ): + if splice_type in ["TSS", "PAS"]: start, end = seg_graph[nX].start, seg_graph[nY].end if (splice_type == "TSS") == (gene.strand == "+"): novel = end not in known.get(splice_type, set()) @@ -334,21 +520,31 @@ def alternative_splicing_events(self, min_total=100, min_alt_fraction=.1, sample else: start, end = seg_graph[nX].end, seg_graph[nY].start novel = (start, end) not in known.get(splice_type, set()) - bubbles.append([gene.id, gene.chrom, start, end, splice_type, novel] + list(junction_cov) + list(total_cov)) - return pd.DataFrame(bubbles, columns=['gene', 'chr', 'start', 'end', 'splice_type', 'novel'] + - [f'{sample}_{what}' for what in ['in_cov', 'total_cov'] for sample in samples]) + bubbles.append( + [gene.id, gene.chrom, start, end, splice_type, novel] + + list(junction_cov) + + list(total_cov) + ) + return pd.DataFrame( + bubbles, + columns=["gene", "chr", "start", "end", "splice_type", "novel"] + + [ + f"{sample}_{what}" for what in ["in_cov", "total_cov"] for sample in samples + ], + ) + # summary tables (can be used as input to plot_bar / plot_dist) # function to optimize (inverse nbinom cdf) -def _tpm_fun(tpm_th, n_reads, cov_th=2, p=.8): - return (p-nbinom.cdf(n_reads - cov_th, n=cov_th, p=tpm_th * 1e-6))**2 +def _tpm_fun(tpm_th, n_reads, cov_th=2, p=0.8): + return (p - nbinom.cdf(n_reads - cov_th, n=cov_th, p=tpm_th * 1e-6)) ** 2 -def estimate_tpm_threshold(n_reads, cov_th=2, p=.8): - '''Estimate the minimum expression level of observable transcripts at given coverage. +def estimate_tpm_threshold(n_reads, cov_th=2, p=0.8): + """Estimate the minimum expression level of observable transcripts at given coverage. 
The function returns the expression level in transcripts per million (TPM), that can be observed at the given sequencing depth. @@ -356,12 +552,20 @@ def estimate_tpm_threshold(n_reads, cov_th=2, p=.8): :param n_reads: The sequencing depth (total number of reads) for the sample. :param cov_th: The requested minimum number of reads per transcripts. :param p: The probability of a transcript at threshold expression level to be observed. - ''' - return minimize_scalar(_tpm_fun, bounds=(.01, 1000), args=(n_reads, cov_th, p))['x'] + """ + return minimize_scalar(_tpm_fun, bounds=(0.01, 1000), args=(n_reads, cov_th, p))[ + "x" + ] -def altsplice_stats(self: 'Transcriptome', groups=None, weight_by_coverage=True, min_coverage=2, tr_filter={}): - '''Summary statistics for novel alternative splicing. +def altsplice_stats( + self: "Transcriptome", + groups=None, + weight_by_coverage=True, + min_coverage=2, + tr_filter=None, +): + """Summary statistics for novel alternative splicing. This function counts the novel alternative splicing events of LRTS isoforms with respect to the reference annotation. The result can be depicted by isotools.plots.plot_bar. @@ -370,7 +574,11 @@ def altsplice_stats(self: 'Transcriptome', groups=None, weight_by_coverage=True, :param weight_by_coverage: If True, each transcript is weighted by the coverage. :param min_coverage: Threshold to ignore poorly covered transcripts. This parameter gets applied for each sample group separately. :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). - :return: Table with numbers of novel alternative splicing events, and suggested parameters for isotools.plots.plot_bar().''' + :return: Table with numbers of novel alternative splicing events, and suggested parameters for isotools.plots.plot_bar(). 
+ """ + if tr_filter is None: + tr_filter = {} + weights = dict() # if groups is not None: # gi={r:i for i,r in enumerate(runs)} @@ -378,38 +586,57 @@ def altsplice_stats(self: 'Transcriptome', groups=None, weight_by_coverage=True, current = None if groups is not None: sample_idx = {sample: i for i, sample in enumerate(self.samples)} # idx - groups = {groupname: [sample_idx[sample] for sample in group] for groupname, group in groups.items()} + groups = { + groupname: [sample_idx[sample] for sample in group] + for groupname, group in groups.items() + } for gene, transcript_id, transcript in self.iter_transcripts(**tr_filter): if gene != current: current = gene - w = gene.coverage.copy() if groups is None else np.array([gene.coverage[grp, :].sum(0) for grp in groups.values()]) + w = ( + gene.coverage.copy() + if groups is None + else np.array([gene.coverage[grp, :].sum(0) for grp in groups.values()]) + ) w[w < min_coverage] = 0 if not weight_by_coverage: w[w > 0] = 1 - if 'annotation' not in transcript or transcript['annotation'] is None: - weights['unknown'] = weights.get('unknown', np.zeros(w.shape[0])) + w[:, transcript_id] + if "annotation" not in transcript or transcript["annotation"] is None: + weights["unknown"] = ( + weights.get("unknown", np.zeros(w.shape[0])) + w[:, transcript_id] + ) else: - for stype in transcript['annotation'][1]: - weights[stype] = weights.get(stype, np.zeros(w.shape[0])) + w[:, transcript_id] - weights['total'] = weights.get('total', np.zeros(w.shape[0])) + w[:, transcript_id] - - df = pd.DataFrame(weights, index=self.samples if groups is None else groups.keys()).T - df = df.reindex(df.mean(1).sort_values(ascending=False).index, axis=0) # sort by row mean + for stype in transcript["annotation"][1]: + weights[stype] = ( + weights.get(stype, np.zeros(w.shape[0])) + w[:, transcript_id] + ) + weights["total"] = ( + weights.get("total", np.zeros(w.shape[0])) + w[:, transcript_id] + ) + + df = pd.DataFrame( + weights, index=self.samples if 
groups is None else groups.keys() + ).T + df = df.reindex( + df.mean(1).sort_values(ascending=False).index, axis=0 + ) # sort by row mean if weight_by_coverage: - title = 'Expressed Transcripts' - ylab = 'fraction of reads' + title = "Expressed Transcripts" + ylab = "fraction of reads" else: - title = 'Different Transcripts' - ylab = 'fraction of different transcripts' + title = "Different Transcripts" + ylab = "fraction of different transcripts" if min_coverage > 1: - title += f' > {min_coverage} reads' + title += f" > {min_coverage} reads" - return df, {'ylabel': ylab, 'title': title} + return df, {"ylabel": ylab, "title": title} -def _check_customised_groups(transcriptome: 'Transcriptome', samples=None, groups=None, sample_idx=False): - ''' +def _check_customised_groups( + transcriptome: "Transcriptome", samples=None, groups=None, sample_idx=False +): + """ Check if the samples and all the samples in groups are consistent, and all found in transcriptome.samples. Customised group names not in transcriptome.groups() are allowed. @@ -417,47 +644,80 @@ def _check_customised_groups(transcriptome: 'Transcriptome', samples=None, group :param groups: A dict {group_name:[sample_name_list]} or a list of group names to tell how to group samples. If omitted, all the samples are considered as one group. :param sample_idx: If True, the samples are specified by sample indices. If False, the samples are specified by sample names. :return: A dict {group_name:[sample_list]} with sample names or indices. 
- ''' + """ if samples is None: samples = transcriptome.samples else: - assert all(s in transcriptome.samples for s in samples), 'not all specified samples found' + assert all( + s in transcriptome.samples for s in samples + ), "not all specified samples found" if isinstance(groups, dict): - assert list(set(sum(groups.values(), []))) == list(set(samples)), 'inconsistent samples specified in samples and in groups' + assert list(set(sum(groups.values(), []))) == list( + set(samples) + ), "inconsistent samples specified in samples and in groups" if groups is None: - group_dict = {'all' if len(samples) == len(transcriptome.samples) else 'selected': samples} + group_dict = { + "all" if len(samples) == len(transcriptome.samples) else "selected": samples + } elif isinstance(groups, dict): - assert all(s in samples for s in sum(groups.values(), [])), 'not all the samples in specified groups are found' + assert all( + s in samples for s in sum(groups.values(), []) + ), "not all the samples in specified groups are found" group_dict = groups elif isinstance(groups, list): - assert all(gn in transcriptome.groups() for gn in groups), 'not all specified groups are found. To customize groups, use a dict {group_name:[sample_name_list]}' - group_dict = {gn: [s for s in transcriptome.groups()[gn] if s in samples] for gn in groups if any(s in samples for s in transcriptome.groups()[gn])} + assert all( + gn in transcriptome.groups() for gn in groups + ), "not all specified groups are found. 
To customize groups, use a dict {group_name:[sample_name_list]}" + group_dict = { + gn: [s for s in transcriptome.groups()[gn] if s in samples] + for gn in groups + if any(s in samples for s in transcriptome.groups()[gn]) + } else: - raise ValueError('groups must be a dict or a list of group names') + raise ValueError("groups must be a dict or a list of group names") if sample_idx: - group_dict = {gn:[transcriptome.samples.index(s) for s in sample_names] for gn, sample_names in group_dict.items()} + group_dict = { + gn: [transcriptome.samples.index(s) for s in sample_names] + for gn, sample_names in group_dict.items() + } return group_dict -def entropy_calculation(self: 'Transcriptome', samples=None, groups=None, min_total=1, relative=True, **kwargs): - ''' +def entropy_calculation( + self: "Transcriptome", + samples=None, + groups=None, + min_total=1, + relative=True, + **kwargs, +): + """ Calculates the entropy of genes based on the coverage of selected transcripts. :param samples: A list of sample names to specify the samples to be considered. If omitted, all samples are selected. - :param groups: Entropy calculation done by groups of samples. A dict {group_name:[sample_name_list]} or a list of group names. If omitted, all the samples are considered as one group. + :param groups: Entropy calculation done by groups of samples. A dict {group_name:[sample_name_list]} or a list of group names. + If omitted, all the samples are considered as one group. :param min_total: Minimum total coverage of a gene over all the samples in a selected group. :param relative: If True, the entropy is normalized by log2 of the number of selected transcripts in the group. :param kwargs: Additional keyword arguments are passed to iter_transcripts. :return: A table of (relative) entropy of genes based on the coverage of selected transcripts. 
- ''' + """ group_idxs = _check_customised_groups(self, samples, groups, sample_idx=True) - entropy_tab = pd.DataFrame(columns=['gene_id', 'gene_name'] + [f'{g}_{c}' for g, c in itertools.product(group_idxs, ['ntr', 'rel_entropy' if relative else 'entropy'])]) + entropy_tab = pd.DataFrame( + columns=["gene_id", "gene_name"] + + [ + f"{g}_{c}" + for g, c in itertools.product( + group_idxs, ["ntr", "rel_entropy" if relative else "entropy"] + ) + ] + ) for gene, transcript_ids, _ in self.iter_transcripts(genewise=True, **kwargs): gene_entropy = [gene.id, gene.name] @@ -468,62 +728,108 @@ def entropy_calculation(self: 'Transcriptome', samples=None, groups=None, min_to gene_entropy += [np.nan, np.nan] else: transcript_number = sum(cov.sum(0) > 0) - group_entropy = -sum(math.log2(p) * p for p in cov.sum(0)[cov.sum(0) > 0] / cov.sum()) + group_entropy = -sum( + math.log2(p) * p for p in cov.sum(0)[cov.sum(0) > 0] / cov.sum() + ) if relative: - group_entropy = (group_entropy / math.log2(transcript_number)) if transcript_number > 1 else np.nan + group_entropy = ( + (group_entropy / math.log2(transcript_number)) + if transcript_number > 1 + else np.nan + ) gene_entropy += [transcript_number, group_entropy] - entropy_tab = pd.concat([entropy_tab, pd.DataFrame([gene_entropy], columns=entropy_tab.columns)], ignore_index=True) + entropy_tab = pd.concat( + [entropy_tab, pd.DataFrame([gene_entropy], columns=entropy_tab.columns)], + ignore_index=True, + ) # exclude rows with all empty or NA entries in entropy columns - entropy_tab.dropna(subset=entropy_tab.columns[2:], how='all', inplace=True) + entropy_tab.dropna(subset=entropy_tab.columns[2:], how="all", inplace=True) return entropy_tab -def str_var_calculation(self: 'Transcriptome', samples=None, groups=None, strict_ec=0, strict_pos=15, count_number=False, **kwargs): - ''' +def str_var_calculation( + self: "Transcriptome", + samples=None, + groups=None, + strict_ec=0, + strict_pos=15, + count_number=False, + **kwargs, +): + """ 
Quantify the structural variation of genes based on selected transcripts. Structural variation includes (and in the same order of) distinct TSS positions, exon chains, and PAS positions. - :param samples: A list of sample names to specify the samples to be considered. If omitted, all samples are selected. - :param groups: Quantification done by groups of samples. A dict {group_name:[sample_name_list]} or a list of group names. If omitted, all the samples are considered as one group. + :param samples: A list of sample names to specify the samples to be considered. + If omitted, all samples are selected. + :param groups: Quantification done by groups of samples. A dict {group_name:[sample_name_list]} or a list of group names. + If omitted, all the samples are considered as one group. :param strict_ec: Distance allowed between each position, except for the first/last, in two exon chains so that they can be considered as identical. :param strict_pos: Difference allowed between two positions when considering identical TSS/PAS. :param count_number: By default False. If True, the number of distinct TSSs, exon chains and PASs in genes directly. :param kwargs: Additional keyword arguments are passed to iter_transcripts. - :return: A table of structural variation of genes based on selected transcripts, including: gene_id, gene_name, and the variation of TSS, exon chain, and PAS for each group of samples. - ''' + :return: A table of structural variation of genes based on selected transcripts, + including: gene_id, gene_name, and the variation of TSS, exon chain, and PAS for each group of samples. 
+ """ group_sns = _check_customised_groups(self, samples, groups, sample_idx=False) - str_var_tab = pd.DataFrame(columns=['gene_id', 'gene_name'] + [f'{g}_{c}' for g, c in itertools.product(group_sns, ['tss', 'ec', 'pas'])]) + str_var_tab = pd.DataFrame( + columns=["gene_id", "gene_name"] + + [f"{g}_{c}" for g, c in itertools.product(group_sns, ["tss", "ec", "pas"])] + ) for gene, _, selected_trs in self.iter_transcripts(genewise=True, **kwargs): gene_str_var = [gene.id, gene.name] for _, selected_samples in group_sns.items(): - group_var = str_var_triplet(transcripts=selected_trs, samples=selected_samples, strict_ec=strict_ec, strict_pos=strict_pos) + group_var = str_var_triplet( + transcripts=selected_trs, + samples=selected_samples, + strict_ec=strict_ec, + strict_pos=strict_pos, + ) if not count_number: # regress out the variation caused by TAS and PAS for exon chain - splicing_ratio = 2 * group_var[1] / (group_var[0] + group_var[2]) if (group_var[0] > 0 and group_var[2] > 0) else 0 + splicing_ratio = ( + 2 * group_var[1] / (group_var[0] + group_var[2]) + if (group_var[0] > 0 and group_var[2] > 0) + else 0 + ) ratio_triplet = [group_var[0], splicing_ratio, group_var[2]] # normalize to the sum of 1 - group_var = [n / sum(ratio_triplet) for n in ratio_triplet] if sum(ratio_triplet) > 0 else [0, 0, 0] + group_var = ( + [n / sum(ratio_triplet) for n in ratio_triplet] + if sum(ratio_triplet) > 0 + else [0, 0, 0] + ) gene_str_var += group_var - str_var_tab = pd.concat([str_var_tab, pd.DataFrame([gene_str_var], columns=str_var_tab.columns)], ignore_index=True) + str_var_tab = pd.concat( + [str_var_tab, pd.DataFrame([gene_str_var], columns=str_var_tab.columns)], + ignore_index=True, + ) # replace 0 with nan, and remove rows with all nan str_var_tab = str_var_tab.replace(0, np.nan) - str_var_tab = str_var_tab.dropna(how='all', subset=str_var_tab.columns[2:]) + str_var_tab = str_var_tab.dropna(how="all", subset=str_var_tab.columns[2:]) return str_var_tab -def 
filter_stats(self: 'Transcriptome', tags=None, groups=None, weight_by_coverage=True, min_coverage=2, **kwargs): - '''Summary statistics for filter flags. +def filter_stats( + self: "Transcriptome", + tags=None, + groups=None, + weight_by_coverage=True, + min_coverage=2, + **kwargs, +): + """Summary statistics for filter flags. This function counts the number of transcripts corresponding to filter tags. The result can be depicted by isotools.plots.plot_bar. @@ -533,48 +839,90 @@ def filter_stats(self: 'Transcriptome', tags=None, groups=None, weight_by_covera :param weight_by_coverage: If True, each transcript is weighted by the number of supporting reads. :param min_coverage: Coverage threshold per sample to ignore poorly covered transcripts. :param kwargs: Additional parameters are passed to self.iter_transcripts(). - :return: Table with numbers of transcripts featuring the filter tag, and suggested parameters for isotools.plots.plot_bar().''' + :return: Table with numbers of transcripts featuring the filter tag, and suggested parameters for isotools.plots.plot_bar(). 
+ """ weights = dict() if tags is None: - tags = list(self.filter['transcript']) - assert all(t in self.filter['transcript'] for t in tags), '"Tags" contains invalid tags' - filterfun = {tag: _filter_function(tag, self.filter['transcript'])[0] for tag in tags} + tags = list(self.filter["transcript"]) + assert all( + t in self.filter["transcript"] for t in tags + ), '"Tags" contains invalid tags' + filterfun = { + tag: _filter_function(tag, self.filter["transcript"])[0] for tag in tags + } if groups is not None: sample_indices = {sample: i for i, sample in enumerate(self.samples)} # idx - groups = {group_name: [sample_indices[sample] for sample in sample_group] for group_name, sample_group in groups.items()} + groups = { + group_name: [sample_indices[sample] for sample in sample_group] + for group_name, sample_group in groups.items() + } current = None for gene, transcript_id, transcript in self.iter_transcripts(**kwargs): if gene != current: current = gene - weight = gene.coverage.copy() if groups is None else np.array([gene.coverage[group, :].sum(0) for group in groups.values()]) + weight = ( + gene.coverage.copy() + if groups is None + else np.array( + [gene.coverage[group, :].sum(0) for group in groups.values()] + ) + ) weight[weight < min_coverage] = 0 if not weight_by_coverage: weight[weight > 0] = 1 # relevant_filter=[filter for filter in transcript['filter'] if consider is None or filter in consider] - relevant_filter = [tag for tag in tags if filterfun[tag](gene=gene, trid=transcript_id, **transcript)] + relevant_filter = [ + tag + for tag in tags + if filterfun[tag](gene=gene, trid=transcript_id, **transcript) + ] for filter in relevant_filter: - weights[filter] = weights.get(filter, np.zeros(weight.shape[0])) + weight[:, transcript_id] + weights[filter] = ( + weights.get(filter, np.zeros(weight.shape[0])) + + weight[:, transcript_id] + ) if not relevant_filter: - weights['PASS'] = weights.get('PASS', np.zeros(weight.shape[0])) + weight[:, transcript_id] - 
weights['total'] = weights.get('total', np.zeros(weight.shape[0])) + weight[:, transcript_id] + weights["PASS"] = ( + weights.get("PASS", np.zeros(weight.shape[0])) + + weight[:, transcript_id] + ) + weights["total"] = ( + weights.get("total", np.zeros(weight.shape[0])) + weight[:, transcript_id] + ) - df = pd.DataFrame(weights, index=self.samples if groups is None else groups.keys()).T + df = pd.DataFrame( + weights, index=self.samples if groups is None else groups.keys() + ).T df = df.reindex(df.mean(1).sort_values(ascending=False).index, axis=0) - ylab = 'fraction of reads' if weight_by_coverage else 'fraction of different transcripts' + ylab = ( + "fraction of reads" + if weight_by_coverage + else "fraction of different transcripts" + ) if weight_by_coverage: - title = 'Expressed Transcripts' + title = "Expressed Transcripts" else: - title = 'Different Transcripts' + title = "Different Transcripts" if min_coverage > 1: - title += f' > {min_coverage} reads' - return df, {'ylabel': ylab, 'title': title} - - -def transcript_length_hist(self: 'Transcriptome', groups=None, add_reference=False, bins=50, x_range=( - 0, 10000), weight_by_coverage=True, min_coverage=2, use_alignment=True, tr_filter={}, ref_filter={}): - '''Retrieves the transcript length distribution. + title += f" > {min_coverage} reads" + return df, {"ylabel": ylab, "title": title} + + +def transcript_length_hist( + self: "Transcriptome", + groups=None, + add_reference=False, + bins=50, + x_range=(0, 10000), + weight_by_coverage=True, + min_coverage=2, + use_alignment=True, + tr_filter=None, + ref_filter=None, +): + """Retrieves the transcript length distribution. This function counts the number of transcripts within length interval. The result can be depicted by isotools.plots.plot_dist. @@ -588,7 +936,12 @@ def transcript_length_hist(self: 'Transcriptome', groups=None, add_reference=Fal :param use_alignment: use the transcript length as defined by the alignment (e.g. the sum of all exon lengths). 
:param tr_filter: Filter dict, that is passed to self.iter_transcripts(). :param ref_filter: Filter dict, that is passed to self.iter_ref_transcripts() (relevant only if add_reference=True). - :return: Table with numbers of transcripts within the length intervals, and suggested parameters for isotools.plots.plot_distr().''' + :return: Table with numbers of transcripts within the length intervals, and suggested parameters for isotools.plots.plot_distr(). + """ + if tr_filter is None: + tr_filter = {} + if ref_filter is None: + ref_filter = {} trlen = [] cov = [] @@ -598,26 +951,45 @@ def transcript_length_hist(self: 'Transcriptome', groups=None, add_reference=Fal current = gene current_cov = gene.coverage cov.append(current_cov[:, transcript_id]) - trlen.append(sum(e[1] - e[0] for e in transcript['exons']) if use_alignment else transcript['source_len']) # source_len is not set in the current version + trlen.append( + sum(e[1] - e[0] for e in transcript["exons"]) + if use_alignment + else transcript["source_len"] + ) # source_len is not set in the current version cov = pd.DataFrame(cov, columns=self.samples) if groups is not None: cov = pd.DataFrame({grn: cov[group].sum(1) for grn, group in groups.items()}) if isinstance(bins, int): - bins = np.linspace(x_range[0] - .5, x_range[1] - .5, bins + 1) + bins = np.linspace(x_range[0] - 0.5, x_range[1] - 0.5, bins + 1) cov[cov < min_coverage] = 0 if not weight_by_coverage: cov[cov > 0] = 1 - counts = pd.DataFrame({gn: np.histogram(trlen, weights=g_cov, bins=bins)[0] for gn, g_cov in cov.items()}) + counts = pd.DataFrame( + { + gn: np.histogram(trlen, weights=g_cov, bins=bins)[0] + for gn, g_cov in cov.items() + } + ) if add_reference: - ref_len = [sum(exon[1] - exon[0] for exon in transcript['exons']) for _, _, transcript in self.iter_ref_transcripts(**ref_filter)] - counts['reference'] = np.histogram(ref_len, bins=bins)[0] - bin_df = pd.DataFrame({'from': bins[:-1], 'to': bins[1:]}) - params = dict(yscale='linear', 
title='transcript length', xlabel='transcript length [bp]', density=True) - return pd.concat([bin_df, counts], axis=1).set_index(['from', 'to']), params - - -def transcript_coverage_hist(self, groups=None, bins=50, x_range=(1, 1001), tr_filter={}): - '''Retrieves the transcript coverage distribution. + ref_len = [ + sum(exon[1] - exon[0] for exon in transcript["exons"]) + for _, _, transcript in self.iter_ref_transcripts(**ref_filter) + ] + counts["reference"] = np.histogram(ref_len, bins=bins)[0] + bin_df = pd.DataFrame({"from": bins[:-1], "to": bins[1:]}) + params = dict( + yscale="linear", + title="transcript length", + xlabel="transcript length [bp]", + density=True, + ) + return pd.concat([bin_df, counts], axis=1).set_index(["from", "to"]), params + + +def transcript_coverage_hist( + self, groups=None, bins=50, x_range=(1, 1001), tr_filter=None +): + """Retrieves the transcript coverage distribution. This function counts the number of transcripts within coverage interval. The result can be depicted by isotools.plots.plot_dist. @@ -626,7 +998,11 @@ def transcript_coverage_hist(self, groups=None, bins=50, x_range=(1, 1001), tr_f :param bins: Define the coverage interval, either by a single number of bins, or by a list of values, defining the interval boundaries. :param x_range: The range of the intervals. Ignored if "bins" is provided as a list. :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). - :return: Table with numbers of transcripts within the coverage intervals, and suggested parameters for isotools.plots.plot_distr().''' + :return: Table with numbers of transcripts within the coverage intervals, and suggested parameters for isotools.plots.plot_distr(). 
+ """ + if tr_filter is None: + tr_filter = {} + # get the transcript coverage in bins for groups # return count dataframe and suggested default parameters for plot_distr cov = [] @@ -640,19 +1016,32 @@ def transcript_coverage_hist(self, groups=None, bins=50, x_range=(1, 1001), tr_f if groups is not None: cov = pd.DataFrame({grn: cov[grp].sum(1) for grn, grp in groups.items()}) if isinstance(bins, int): - bins = np.linspace(x_range[0] - .5, x_range[1] - .5, bins + 1) - counts = pd.DataFrame({gn: np.histogram(g_cov, bins=bins)[0] for gn, g_cov in cov.items()}) - bin_df = pd.DataFrame({'from': bins[:-1], 'to': bins[1:]}) - params = dict(yscale='log', title='transcript coverage', xlabel='reads per transcript') - return pd.concat([bin_df, counts], axis=1).set_index(['from', 'to']), params + bins = np.linspace(x_range[0] - 0.5, x_range[1] - 0.5, bins + 1) + counts = pd.DataFrame( + {gn: np.histogram(g_cov, bins=bins)[0] for gn, g_cov in cov.items()} + ) + bin_df = pd.DataFrame({"from": bins[:-1], "to": bins[1:]}) + params = dict( + yscale="log", title="transcript coverage", xlabel="reads per transcript" + ) + return pd.concat([bin_df, counts], axis=1).set_index(["from", "to"]), params # plot histogram # cov.mask(cov.lt(x_range[0]) | cov.gt(x_range[1])).plot.hist(ax=ax, alpha=0.5, bins=n_bins) # ax=counts.plot.bar() # ax.plot(x, counts) -def transcripts_per_gene_hist(self, groups=None, add_reference=False, bins=49, x_range=(1, 50), min_coverage=2, tr_filter={}, ref_filter={}): - '''Retrieves the histogram of number of transcripts per gene. +def transcripts_per_gene_hist( + self, + groups=None, + add_reference=False, + bins=49, + x_range=(1, 50), + min_coverage=2, + tr_filter=None, + ref_filter=None, +): + """Retrieves the histogram of number of transcripts per gene. This function counts the genes featuring transcript numbers within specified intervals. The result can be depicted by isotools.plots.plot_dist. 
@@ -665,7 +1054,13 @@ def transcripts_per_gene_hist(self, groups=None, add_reference=False, bins=49, x :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). :param ref_filter: Filter dict, that is passed to self.iter_ref_transcripts() (relevant only if add_reference=True). :return: Table with numbers of genes featuring transcript numbers within the specified intervals, - and suggested parameters for isotools.plots.plot_distr().''' + and suggested parameters for isotools.plots.plot_distr(). + """ + if tr_filter is None: + tr_filter = {} + if ref_filter is None: + ref_filter = {} + ntr = [] current = None if groups is None: @@ -673,35 +1068,55 @@ def transcripts_per_gene_hist(self, groups=None, add_reference=False, bins=49, x else: group_names = groups.keys() sidx = {sample: i for i, sample in enumerate(self.samples)} # idx - groups = {groupname: [sidx[sample] for sample in group] for groupname, group in groups.items()} + groups = { + groupname: [sidx[sample] for sample in group] + for groupname, group in groups.items() + } n_sa = len(group_names) for gene, transcript_id, _ in self.iter_transcripts(**tr_filter): if gene != current: current = gene - current_cov = gene.coverage if groups is None else np.array([gene.coverage[grp, :].sum(0) for grp in groups.values()]) + current_cov = ( + gene.coverage + if groups is None + else np.array([gene.coverage[grp, :].sum(0) for grp in groups.values()]) + ) ntr.append(np.zeros(n_sa)) ntr[-1] += current_cov[:, transcript_id] >= min_coverage ntr = pd.DataFrame((n for n in ntr if n.sum() > 0), columns=group_names) if isinstance(bins, int): - bins = np.linspace(x_range[0] - .5, x_range[1] - .5, bins + 1) + bins = np.linspace(x_range[0] - 0.5, x_range[1] - 0.5, bins + 1) counts = pd.DataFrame({gn: np.histogram(n, bins=bins)[0] for gn, n in ntr.items()}) if add_reference: if ref_filter: - logger.warning('reference filter not implemented') - ref_ntr = [gene.n_ref_transcripts for gene in self] # todo: add reference 
filter - counts['reference'] = np.histogram(ref_ntr, bins=bins)[0] - bin_df = pd.DataFrame({'from': bins[:-1], 'to': bins[1:]}) - sub = f'counting transcripts covered by >= {min_coverage} reads' - if 'query' in tr_filter: + logger.warning("reference filter not implemented") + ref_ntr = [ + gene.n_ref_transcripts for gene in self + ] # todo: add reference filter + counts["reference"] = np.histogram(ref_ntr, bins=bins)[0] + bin_df = pd.DataFrame({"from": bins[:-1], "to": bins[1:]}) + sub = f"counting transcripts covered by >= {min_coverage} reads" + if "query" in tr_filter: sub += f', filter query: {tr_filter["query"]}' - params = dict(yscale='log', title='transcript per gene\n' + sub, xlabel='transcript per gene') - return pd.concat([bin_df, counts], axis=1).set_index(['from', 'to']), params - - -def exons_per_transcript_hist(self, groups=None, add_reference=False, bins=34, x_range=(1, 69), - weight_by_coverage=True, min_coverage=2, tr_filter={}, ref_filter={}): - '''Retrieves the histogram of number of exons per transcript. + params = dict( + yscale="log", title="transcript per gene\n" + sub, xlabel="transcript per gene" + ) + return pd.concat([bin_df, counts], axis=1).set_index(["from", "to"]), params + + +def exons_per_transcript_hist( + self, + groups=None, + add_reference=False, + bins=34, + x_range=(1, 69), + weight_by_coverage=True, + min_coverage=2, + tr_filter=None, + ref_filter=None, +): + """Retrieves the histogram of number of exons per transcript. This function counts the transcripts featuring exon numbers within specified intervals. The result can be depicted by isotools.plots.plot_dist. @@ -715,7 +1130,13 @@ def exons_per_transcript_hist(self, groups=None, add_reference=False, bins=34, x :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). :param ref_filter: Filter dict, that is passed to self.iter_ref_transcripts() (relevant only if add_reference=True). 
:return: Table with numbers of transcripts featuring exon numbers within the specified intervals, - and suggested parameters for isotools.plots.plot_distr().''' + and suggested parameters for isotools.plots.plot_distr(). + """ + if tr_filter is None: + tr_filter = {} + if ref_filter is None: + ref_filter = {} + n_exons = [] cov = [] current = None @@ -724,29 +1145,51 @@ def exons_per_transcript_hist(self, groups=None, add_reference=False, bins=34, x current = gene current_cov = gene.coverage cov.append(current_cov[:, transcript_id]) - n_exons.append(len(transcript['exons'])) + n_exons.append(len(transcript["exons"])) cov = pd.DataFrame(cov, columns=self.samples) if groups is not None: cov = pd.DataFrame({grn: cov[grp].sum(1) for grn, grp in groups.items()}) if isinstance(bins, int): - bins = np.linspace(x_range[0] - .5, x_range[1] - .5, bins + 1) + bins = np.linspace(x_range[0] - 0.5, x_range[1] - 0.5, bins + 1) cov[cov < min_coverage] = 0 if not weight_by_coverage: cov[cov > 0] = 1 - counts = pd.DataFrame({gn: np.histogram(n_exons, weights=g_cov, bins=bins)[0] for gn, g_cov in cov.items()}) + counts = pd.DataFrame( + { + gn: np.histogram(n_exons, weights=g_cov, bins=bins)[0] + for gn, g_cov in cov.items() + } + ) if add_reference: - ref_n_exons = [len(transcript['exons']) for _, _, transcript in self.iter_ref_transcripts(**ref_filter)] - counts['reference'] = np.histogram(ref_n_exons, bins=bins)[0] - bin_df = pd.DataFrame({'from': bins[:-1], 'to': bins[1:]}) - sub = f'counting transcripts covered by >= {min_coverage} reads' - if 'query' in tr_filter: + ref_n_exons = [ + len(transcript["exons"]) + for _, _, transcript in self.iter_ref_transcripts(**ref_filter) + ] + counts["reference"] = np.histogram(ref_n_exons, bins=bins)[0] + bin_df = pd.DataFrame({"from": bins[:-1], "to": bins[1:]}) + sub = f"counting transcripts covered by >= {min_coverage} reads" + if "query" in tr_filter: sub += f', filter query: {tr_filter["query"]}' - params = dict(yscale='log', 
title='exons per transcript\n' + sub, xlabel='number of exons per transcript') - return pd.concat([bin_df, counts], axis=1).set_index(['from', 'to']), params - - -def downstream_a_hist(self, groups=None, add_reference=False, bins=30, x_range=(0, 1), weight_by_coverage=True, min_coverage=2, transcript_filter={}, ref_filter={}): - '''Retrieves the distribution of downstream adenosine content. + params = dict( + yscale="log", + title="exons per transcript\n" + sub, + xlabel="number of exons per transcript", + ) + return pd.concat([bin_df, counts], axis=1).set_index(["from", "to"]), params + + +def downstream_a_hist( + self, + groups=None, + add_reference=False, + bins=30, + x_range=(0, 1), + weight_by_coverage=True, + min_coverage=2, + transcript_filter=None, + ref_filter=None, +): + """Retrieves the distribution of downstream adenosine content. High downstream adenosine content is indicative for internal priming. @@ -758,7 +1201,13 @@ def downstream_a_hist(self, groups=None, add_reference=False, bins=30, x_range=( :param min_coverage: Threshold to ignore poorly covered transcripts. :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). :param ref_filter: Filter dict, that is passed to self.iter_ref_transcripts() (relevant only if add_reference=True). - :return: Table with downstream adenosine content distribution, and suggested parameters for isotools.plots.plot_distr().''' + :return: Table with downstream adenosine content distribution, and suggested parameters for isotools.plots.plot_distr(). 
+ """ + if transcript_filter is None: + transcript_filter = {} + if ref_filter is None: + ref_filter = {} + acontent = [] cov = [] current = None @@ -768,7 +1217,7 @@ def downstream_a_hist(self, groups=None, add_reference=False, bins=30, x_range=( current_cov = gene.coverage cov.append(current_cov[:, transcript_id]) try: - acontent.append(transcript['downstream_A_content']) + acontent.append(transcript["downstream_A_content"]) except KeyError: acontent.append(-1) cov = pd.DataFrame(cov, columns=self.samples) @@ -779,17 +1228,37 @@ def downstream_a_hist(self, groups=None, add_reference=False, bins=30, x_range=( cov[cov < min_coverage] = 0 if not weight_by_coverage: cov[cov > 0] = 1 - counts = pd.DataFrame({group_name: np.histogram(acontent, weights=group_cov, bins=bins)[0] for group_name, group_cov in cov.items()}) + counts = pd.DataFrame( + { + group_name: np.histogram(acontent, weights=group_cov, bins=bins)[0] + for group_name, group_cov in cov.items() + } + ) if add_reference: - ref_acontent = [transcript['downstream_A_content'] for _, _, transcript in self.iter_ref_transcripts(**ref_filter) if 'downstream_A_content' in transcript] - counts['reference'] = np.histogram(ref_acontent, bins=bins)[0] - bin_df = pd.DataFrame({'from': bins[:-1], 'to': bins[1:]}) - params = dict(title='downstream genomic A content', xlabel='fraction of A downstream the transcript') - return pd.concat([bin_df, counts], axis=1).set_index(['from', 'to']), params - - -def direct_repeat_hist(self, groups=None, bins=10, x_range=(0, 10), weight_by_coverage=True, min_coverage=2, tr_filter={}): - '''Retrieves the distribution direct repeat length at splice junctions. 
+ ref_acontent = [ + transcript["downstream_A_content"] + for _, _, transcript in self.iter_ref_transcripts(**ref_filter) + if "downstream_A_content" in transcript + ] + counts["reference"] = np.histogram(ref_acontent, bins=bins)[0] + bin_df = pd.DataFrame({"from": bins[:-1], "to": bins[1:]}) + params = dict( + title="downstream genomic A content", + xlabel="fraction of A downstream the transcript", + ) + return pd.concat([bin_df, counts], axis=1).set_index(["from", "to"]), params + + +def direct_repeat_hist( + self, + groups=None, + bins=10, + x_range=(0, 10), + weight_by_coverage=True, + min_coverage=2, + tr_filter=None, +): + """Retrieves the distribution direct repeat length at splice junctions. Direct repeats are indicative for reverse transcriptase template switching. @@ -799,40 +1268,77 @@ def direct_repeat_hist(self, groups=None, bins=10, x_range=(0, 10), weight_by_co :param weight_by_coverage: If True, each transcript is weighted by the coverage. :param min_coverage: Threshold to ignore poorly covered transcripts. :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). - :return: Table with direct repeat length distribution, and suggested parameters for isotools.plots.plot_distr().''' + :return: Table with direct repeat length distribution, and suggested parameters for isotools.plots.plot_distr(). 
+ """ + if tr_filter is None: + tr_filter = {} + # find the direct repeat length distribution in FSM transcripts and putative RTTS # putative RTTS are identified by introns where both splice sites are novel but within annotated exons # TODO: actually no need to check annotation, could simply use filter flags (or the definition from the filter flags, which should be faster) - rl = {cat: [] for cat in ('known', 'novel canonical', 'novel noncanonical')} + rl = {cat: [] for cat in ("known", "novel canonical", "novel noncanonical")} for gene, transcript_id, transcript in self.iter_transcripts(**tr_filter): - if 'annotation' in transcript and transcript['annotation'][0] == 0: # e.g. FSM - rl['known'].extend((drl, gene.coverage[:, transcript_id]) for drl in transcript['direct_repeat_len']) - elif gene.is_annotated and 'novel_splice_sites' in transcript: - novel_junction = [i // 2 for i in transcript['novel_splice_sites'] if i % 2 == 0 and i + 1 in transcript['novel_splice_sites']] - nc = {v[0] for v in transcript.get('noncanonical_splicing', [])} - rl['novel noncanonical'].extend((transcript['direct_repeat_len'][sj], gene.coverage[:, transcript_id]) for sj in novel_junction if sj in nc) - rl['novel canonical'].extend((transcript['direct_repeat_len'][sj], gene.coverage[:, transcript_id]) for sj in novel_junction if sj not in nc) - - rl_cov = {cat: pd.DataFrame((v[1] for v in rl[cat]), columns=self.samples) for cat in rl} + if "annotation" in transcript and transcript["annotation"][0] == 0: # e.g. 
FSM + rl["known"].extend( + (drl, gene.coverage[:, transcript_id]) + for drl in transcript["direct_repeat_len"] + ) + elif gene.is_annotated and "novel_splice_sites" in transcript: + novel_junction = [ + i // 2 + for i in transcript["novel_splice_sites"] + if i % 2 == 0 and i + 1 in transcript["novel_splice_sites"] + ] + nc = {v[0] for v in transcript.get("noncanonical_splicing", [])} + rl["novel noncanonical"].extend( + (transcript["direct_repeat_len"][sj], gene.coverage[:, transcript_id]) + for sj in novel_junction + if sj in nc + ) + rl["novel canonical"].extend( + (transcript["direct_repeat_len"][sj], gene.coverage[:, transcript_id]) + for sj in novel_junction + if sj not in nc + ) + + rl_cov = { + cat: pd.DataFrame((v[1] for v in rl[cat]), columns=self.samples) for cat in rl + } if groups is not None: - rl_cov = {cat: pd.DataFrame({grn: rl_cov[cat][grp].sum(1) for grn, grp in groups.items()}) for cat in rl_cov} + rl_cov = { + cat: pd.DataFrame( + {grn: rl_cov[cat][grp].sum(1) for grn, grp in groups.items()} + ) + for cat in rl_cov + } for cov_df in rl_cov.values(): cov_df[cov_df < min_coverage] = 0 if not weight_by_coverage: cov_df[cov_df > 0] = 1 if isinstance(bins, int): - bins = np.linspace(x_range[0] - .5, x_range[1] - .5, bins + 1) - counts = pd.DataFrame({f'{sample} {cat}': np.histogram([val[0] for val in rl_list], weights=rl_cov[cat][sample], bins=bins)[ - 0] for cat, rl_list in rl.items() for sample in (self.samples if groups is None else groups)}) - - bin_df = pd.DataFrame({'from': bins[:-1], 'to': bins[1:]}) - params = dict(title='direct repeat length', xlabel='length of direct repeats at splice junctons', ylabel='# transcripts') - - return pd.concat([bin_df, counts], axis=1).set_index(['from', 'to']), params - - -def rarefaction(self, groups=None, fractions=20, min_coverage=2, tr_filter={}): - '''Rarefaction analysis + bins = np.linspace(x_range[0] - 0.5, x_range[1] - 0.5, bins + 1) + counts = pd.DataFrame( + { + f"{sample} {cat}": np.histogram( + 
[val[0] for val in rl_list], weights=rl_cov[cat][sample], bins=bins + )[0] + for cat, rl_list in rl.items() + for sample in (self.samples if groups is None else groups) + } + ) + + bin_df = pd.DataFrame({"from": bins[:-1], "to": bins[1:]}) + params = dict( + title="direct repeat length", + xlabel="length of direct repeats at splice junctons", + ylabel="# transcripts", + ) + + return pd.concat([bin_df, counts], axis=1).set_index(["from", "to"]), params + + +def rarefaction(self, groups=None, fractions=20, min_coverage=2, tr_filter=None): + """Rarefaction analysis Reads are sub-sampled according to the provided fractions, to estimate saturation of the transcriptome. @@ -843,7 +1349,10 @@ def rarefaction(self, groups=None, fractions=20, min_coverage=2, tr_filter={}): :param tr_filter: Filter dict, that is passed to self.iter_transcripts(). :return: Tuple with: 1) Data frame containing the number of discovered transcripts, for each sub-sampling fraction and each sample / sample group. - 2) Dict with total number of reads for each group. ''' + 2) Dict with total number of reads for each group. 
+ """ + if tr_filter is None: + tr_filter = {} cov = [] current = None @@ -855,20 +1364,37 @@ def rarefaction(self, groups=None, fractions=20, min_coverage=2, tr_filter={}): current_cov = gene.coverage cov.append(current_cov[:, transcript_id]) cov = pd.DataFrame(cov, columns=self.samples) - total = dict(self.sample_table.set_index('name').nonchimeric_reads) + total = dict(self.sample_table.set_index("name").nonchimeric_reads) if groups is not None: cov = pd.DataFrame({grn: cov[grp].sum(1) for grn, grp in groups.items()}) - total = {groupname: sum(n for sample, n in total.items() if sample in group) for groupname, group in groups.items()} + total = { + groupname: sum(n for sample, n in total.items() if sample in group) + for groupname, group in groups.items() + } curves = {} for sample in cov: - curves[sample] = [(np.random.binomial(n=cov[sample], p=th) >= min_coverage).sum() for th in fractions] + curves[sample] = [ + (np.random.binomial(n=cov[sample], p=th) >= min_coverage).sum() + for th in fractions + ] return pd.DataFrame(curves, index=fractions), total -def coordination_test(self: 'Transcriptome', samples=None, test: Literal['fisher', 'chi2'] = "fisher", min_dist_AB=1, min_dist_events=1, min_total=100, min_alt_fraction=.1, - events_dict=None, event_type: list[ASEType] = ("ES", "5AS", "3AS", "IR", "ME"), padj_method="fdr_bh", - transcript_filter: Optional[str] = None, **kwargs) -> pd.DataFrame: - '''Performs gene_coordination_test on all genes. +def coordination_test( + self: "Transcriptome", + samples=None, + test: Literal["fisher", "chi2"] = "fisher", + min_dist_AB=1, + min_dist_events=1, + min_total=100, + min_alt_fraction=0.1, + events_dict=None, + event_type: list[ASEType] = ("ES", "5AS", "3AS", "IR", "ME"), + padj_method="fdr_bh", + transcript_filter: Optional[str] = None, + **kwargs, +) -> pd.DataFrame: + """Performs gene_coordination_test on all genes. :param samples: Specify the samples that should be considered in the test. 
The samples can be provided either as a single group name, a list of sample names, or a list of sample indices. @@ -893,7 +1419,7 @@ def coordination_test(self: 'Transcriptome', samples=None, test: Literal['fisher test is used), the log2 OR, the gene Id, the gene name, the type of the first ASE, the type of the second ASE, the starting coordinate of the first ASE, the ending coordinate of the first ASE, the starting coordinate of the second ASE, the ending coordinate of the second ASE, - and the four entries of the contingency table.''' + and the four entries of the contingency table.""" test_res = [] @@ -918,16 +1444,43 @@ def coordination_test(self: 'Transcriptome', samples=None, test: Literal['fisher test_res.extend(next_test_res) except Exception as e: - logger.error(f"\nError encountered on {print(gene)} {gene.id} : {gene.name}.") + logger.error( + f"\nError encountered on {print(gene)} {gene.id}: {gene.name}." + ) raise e - col_names = ("gene_id", "gene_name", "strand", "eventA_type", "eventB_type", "eventA_start", "eventA_end", - "eventB_start", "eventB_end", "pvalue", "statistic", "log2OR", "dcPSI_AB", "dcPSI_BA", "priA_priB", "priA_altB", "altA_priB", - "altA_altB", "priA_priB_transcript_ids", "priA_altB_transcript_ids", "altA_priB_transcript_ids", "altA_altB_transcript_ids") + col_names = ( + "gene_id", + "gene_name", + "strand", + "eventA_type", + "eventB_type", + "eventA_start", + "eventA_end", + "eventB_start", + "eventB_end", + "pvalue", + "statistic", + "log2OR", + "dcPSI_AB", + "dcPSI_BA", + "priA_priB", + "priA_altB", + "altA_priB", + "altA_altB", + "priA_priB_transcript_ids", + "priA_altB_transcript_ids", + "altA_priB_transcript_ids", + "altA_altB_transcript_ids", + ) res = pd.DataFrame(test_res, columns=col_names) - adj_p_value = multi.multipletests(res.pvalue, method=padj_method)[1] if len(res.pvalue) > 0 else [] + adj_p_value = ( + multi.multipletests(res.pvalue, method=padj_method)[1] + if len(res.pvalue) > 0 + else [] + ) res.insert(10, 
"padj", adj_p_value) diff --git a/src/isotools/_utils.py b/src/isotools/_utils.py index 258508c..c6a1fce 100644 --- a/src/isotools/_utils.py +++ b/src/isotools/_utils.py @@ -9,46 +9,61 @@ from scipy.stats import chi2_contingency, fisher_exact import math from typing import Literal, TypeAlias, TYPE_CHECKING -from intervaltree import Interval, IntervalTree +from intervaltree import IntervalTree if TYPE_CHECKING: from isotools.transcriptome import Transcriptome from .splice_graph import SegmentGraph -ASEType: TypeAlias = Literal['ES', '3AS', '5AS', 'IR', 'ME', 'TSS', 'PAS'] +ASEType: TypeAlias = Literal["ES", "3AS", "5AS", "IR", "ME", "TSS", "PAS"] ASEvent: TypeAlias = tuple[list[int], list[int], int, int, ASEType] -''' +""" In order: - transcripts supporting the primary event (the longer path for the basic event types) - transcripts supporting the alternative event (the shorter path for the basic event types) - node A id - node B id - event type -''' +""" # from Kozak et al, NAR, 1987 -kozak = np.array([[23, 35, 23, 19], [26, 35, 21, 18], [25, 35, 22, 18], [23, 26, 33, 18], [19, 39, 23, 19], [23, 37, 20, 20], [ - 17, 19, 44, 20], [18, 39, 23, 20], [25, 53, 15, 7], [61, 2, 36, 1], [27, 49, 13, 11], [15, 55, 21, 9], [23, 16, 46, 15]]) -bg = kozak.sum(0)/kozak.sum() +kozak = np.array( + [ + [23, 35, 23, 19], + [26, 35, 21, 18], + [25, 35, 22, 18], + [23, 26, 33, 18], + [19, 39, 23, 19], + [23, 37, 20, 20], + [17, 19, 44, 20], + [18, 39, 23, 20], + [25, 53, 15, 7], + [61, 2, 36, 1], + [27, 49, 13, 11], + [15, 55, 21, 9], + [23, 16, 46, 15], + ] +) +bg = kozak.sum(0) / kozak.sum() kozak.sum(1) # check they sum up to 100% -kozak_weights = np.log2(kozak/100/bg) +kozak_weights = np.log2(kozak / 100 / bg) kozak_weights = np.c_[kozak_weights, np.zeros(kozak_weights.shape[0])] -kozak_pos = list(range(-12, 0))+[3] -DEFAULT_KOZAK_PWM = pd.DataFrame(kozak_weights.T, columns=kozak_pos, index=[*'ACGTN']) -logger = logging.getLogger('isotools') +kozak_pos = list(range(-12, 0)) + [3] 
+DEFAULT_KOZAK_PWM = pd.DataFrame(kozak_weights.T, columns=kozak_pos, index=[*"ACGTN"]) +logger = logging.getLogger("isotools") -cigar = 'MIDNSHP=XB' +cigar = "MIDNSHP=XB" cigar_lup = {c: i for i, c in enumerate(cigar)} -compl = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'} +compl = {"A": "T", "T": "A", "C": "G", "G": "C"} def rc(seq): - '''reverse complement of seq + """reverse complement of seq :param seq: sequence - :return: reverse complement of seq''' - return ''.join(reversed([compl[c] if c in compl else 'N' for c in seq])) + :return: reverse complement of seq""" + return "".join(reversed([compl[c] if c in compl else "N" for c in seq])) def get_error_rate(bam_fn, n=1000): @@ -58,52 +73,63 @@ def get_error_rate(bam_fn, n=1000): if n is None: stats = align.get_index_statistics() n = sum([s.mapped for s in stats]) - with tqdm(total=n, unit=' reads') as pbar: + with tqdm(total=n, unit=" reads") as pbar: for i, read in enumerate(align): total_len += len(read.query_qualities) - qual += sum([10**(-q/10) for q in read.query_qualities]) + qual += sum([10 ** (-q / 10) for q in read.query_qualities]) pbar.update(1) - if i+1 >= n: + if i + 1 >= n: break - return (qual/total_len)*100 + return (qual / total_len) * 100 -def basequal_hist(bam_fn, qual_bins=10**(np.linspace(-7, 0, 30)), len_bins=None, n=10000): - '''calculates base quality statistics for a bam file: +def basequal_hist(bam_fn, qual_bins=None, len_bins=None, n=10000): + """calculates base quality statistics for a bam file: :param bam_fn: path to bam file :param qual_bins: list of quality thresholds for binning :param len_bins: list of read length thresholds for binning :param n: number of reads to use for statistics - :return: pandas Series or DataFrame with base quality statistics''' + :return: pandas Series or DataFrame with base quality statistics + """ + if qual_bins is None: + qual_bins = 10 ** (np.linspace(-7, 0, 30)) - n_len_bins = 1 if len_bins is None else len(len_bins)+1 - qual = 
np.zeros((len(qual_bins)+1, n_len_bins), dtype=int) + n_len_bins = 1 if len_bins is None else len(len_bins) + 1 + qual = np.zeros((len(qual_bins) + 1, n_len_bins), dtype=int) len_i = 0 i = 0 with AlignmentFile(bam_fn, "rb") as align: if n is None: stats = align.get_index_statistics() n = sum([s.mapped for s in stats]) - with tqdm(total=n, unit=' reads') as pbar: + with tqdm(total=n, unit=" reads") as pbar: for read in align: if read.query_qualities is None: pbar.update(1) continue readl = len(read.query_qualities) if len_bins is not None: - len_i = next((i for i, th in enumerate(len_bins) if readl < th), len(len_bins)) - error_rate = sum([10**(-q/10) for q in read.query_qualities])/readl*100 - q_i = next((i for i, th in enumerate(qual_bins) if error_rate < th), len(qual_bins)) + len_i = next( + (i for i, th in enumerate(len_bins) if readl < th), + len(len_bins), + ) + error_rate = ( + sum([10 ** (-q / 10) for q in read.query_qualities]) / readl * 100 + ) + q_i = next( + (i for i, th in enumerate(qual_bins) if error_rate < th), + len(qual_bins), + ) qual[q_i, len_i] += 1 pbar.update(1) i += 1 if i >= n: break - idx = [f'<{th:.2E} %' for th in qual_bins]+[f'>={qual_bins[-1]:.2E} %'] + idx = [f"<{th:.2E} %" for th in qual_bins] + [f">={qual_bins[-1]:.2E} %"] if len_bins is None: return pd.Series(qual[:, 0], index=idx) - col = [f'<{th/1000:.1f} kb' for th in len_bins]+[f'>={len_bins[-1]/1000:.1f} kb'] + col = [f"<{th/1000:.1f} kb" for th in len_bins] + [f">={len_bins[-1]/1000:.1f} kb"] return pd.DataFrame(qual, index=idx, columns=col) @@ -115,16 +141,16 @@ def pairwise(iterable): # e.g. usefull for enumerating introns def cigar_string2tuples(cigarstring): - '''converts cigar string to tuples ((operator_id, length), ...) + """converts cigar string to tuples ((operator_id, length), ...) 
:param cigarstring: cigar string - :return: tuple of tuples''' + :return: tuple of tuples""" - res = re.findall(f'(\\d+)([{cigar}]+)', cigarstring) + res = re.findall(f"(\\d+)([{cigar}]+)", cigarstring) return tuple((cigar_lup[c], int(n)) for n, c in res) def junctions_from_cigar(cigartuples, offset): - 'returns the exon positions' + "returns the exon positions" exons = list([[offset, offset]]) for cigar in cigartuples: # N -> Splice junction @@ -144,29 +170,29 @@ def junctions_from_cigar(cigartuples, offset): return exons -def is_same_gene(tr1, tr2, spj_iou_th=0, reg_iou_th=.5): - 'Checks whether tr1 and tr2 are the same gene by calculating intersection over union of the intersects' +def is_same_gene(tr1, tr2, spj_iou_th=0, reg_iou_th=0.5): + "Checks whether tr1 and tr2 are the same gene by calculating intersection over union of the intersects" # current default definition of "same gene": at least one shared splice site # or more than 50% exonic overlap spj_i, reg_i = get_intersects(tr1, tr2) - total_spj = (len(tr1)+len(tr2)-2)*2 - spj_iou = spj_i/(total_spj-spj_i) if total_spj > 0 else 0 + total_spj = (len(tr1) + len(tr2) - 2) * 2 + spj_iou = spj_i / (total_spj - spj_i) if total_spj > 0 else 0 if spj_iou > spj_iou_th: return True - total_len = sum([e[1]-e[0] for e in tr2+tr1]) - reg_iou = reg_i/(total_len-reg_i) + total_len = sum([e[1] - e[0] for e in tr2 + tr1]) + reg_iou = reg_i / (total_len - reg_i) if reg_iou > reg_iou_th: return True return False def splice_identical(exon_list1, exon_list2, strictness=math.inf): - ''' + """ Check whether two transcripts are identical in terms of splice sites. :param exon_list1: transcript 1 as a list of tuples for each exon :param exon_list2: transcript 2 as a list of tuples for each exon :param strictness: Number of bp that are allowed to differ for transcription start and end sites to be still considered identical. 
- ''' + """ # all splice sites are equal # different number of exons if len(exon_list1) != len(exon_list2): @@ -175,7 +201,10 @@ def splice_identical(exon_list1, exon_list2, strictness=math.inf): if len(exon_list1) == 1 and has_overlap(exon_list1[0], exon_list2[0]): return True # Check start of first and end of last exon - if abs(exon_list1[0][0] - exon_list2[0][0]) > strictness or abs(exon_list1[-1][1] - exon_list2[-1][1]) > strictness: + if ( + abs(exon_list1[0][0] - exon_list2[0][0]) > strictness + or abs(exon_list1[-1][1] - exon_list2[-1][1]) > strictness + ): return False # check end of first and and start of last exon if exon_list1[0][1] != exon_list2[0][1] or exon_list1[-1][0] != exon_list2[-1][0]: @@ -188,27 +217,48 @@ def splice_identical(exon_list1, exon_list2, strictness=math.inf): def kozak_score(sequence, pos, pwm=DEFAULT_KOZAK_PWM): - return sum(pwm.loc[sequence[pos+i], i] for i in pwm.columns if pos+i >= 0 and pos+i < len(sequence)) + return sum( + pwm.loc[sequence[pos + i], i] + for i in pwm.columns + if pos + i >= 0 and pos + i < len(sequence) + ) + + +def find_orfs(sequence, start_codons=None, stop_codons=None, ref_cds=None): + """Find all open reading frames on the forward strand of the sequence. + :param sequence: DNA sequence to search for ORFs. + :param start_codons: List of start codons (default: ["ATG"]). + :param stop_codons: List of stop codons (default: ["TAA", "TAG", "TGA"]). + :param ref_cds: Dictionary of reference CDS (default: {}). + + :return: List of ORFs as tuples, containing a 7-tuple with start and stop position, reading frame (0,1 or 2), start and stop codon sequence, + number of upstream start codons, and the ids of the reference transcripts with matching CDS initialization. 
+ """ + if start_codons is None: + start_codons = ["ATG"] + if stop_codons is None: + stop_codons = ["TAA", "TAG", "TGA"] + if ref_cds is None: + ref_cds = {} - -def find_orfs(sequence, start_codons=["ATG"], stop_codons=['TAA', 'TAG', 'TGA'], ref_cds={}): - '''Find all open reading frames on the forward strand of the sequence. - Return a 7-tuple with start and stop position, reading frame (0,1 or 2), start and stop codon sequence, - number of upstream start codons, and the ids of the reference transcripts with matching CDS initialization.''' orf = [] starts = [[], [], []] stops = [[], [], []] for init, ref_ids in ref_cds.items(): - starts[init % 3].append((init, sequence[init:(init+3)], ref_ids)) + starts[init % 3].append((init, sequence[init : (init + 3)], ref_ids)) for match in re.finditer("|".join(start_codons), sequence): if match.start() not in ref_cds: - starts[match.start() % 3].append((match.start(), match.group(), None)) # position and codon + starts[match.start() % 3].append( + (match.start(), match.group(), None) + ) # position and codon for match in re.finditer("|".join(stop_codons), sequence): stops[match.start() % 3].append((match.end(), match.group())) for frame in range(3): stop, stop_codon = (0, None) for start, start_codon, ref_ids in starts[frame]: - if start < stop and ref_ids is None: # inframe start within the previous ORF + if ( + start < stop and ref_ids is None + ): # inframe start within the previous ORF continue try: stop, stop_codon = next(s for s in sorted(stops[frame]) if s[0] > start) @@ -250,11 +300,11 @@ def get_intersects(tr1, tr2): if has_overlap(tr1_exon, tr2_exon): if tr1_exon[0] == tr2_exon[0] and i > 0 and j > 0: sjintersect += 1 - if tr1_exon[1] == tr2_exon[1] and i < len(tr1)-1 and j < len(tr2)-1: + if tr1_exon[1] == tr2_exon[1] and i < len(tr1) - 1 and j < len(tr2) - 1: sjintersect += 1 i_end = min(tr1_exon[1], tr2_exon[1]) i_start = max(tr1_exon[0], tr2_exon[0]) - intersect += (i_end - i_start) + intersect += i_end - 
i_start if tr1_exon[1] <= tr2_exon[0]: i, tr1_exon = next(tr1_enum) else: @@ -263,12 +313,15 @@ def get_intersects(tr1, tr2): return sjintersect, intersect -def _filter_function(expression, context_filters = {}): - ''' +def _filter_function(expression, context_filters=None): + """ converts a string e.g. "all(x[0]/x[1]>3)" into a function if context_filters is provided, filter tags will be recursively replaced with their expression - ''' - assert isinstance(expression, str), 'expression should be a string' + """ + if context_filters is None: + context_filters = {} + + assert isinstance(expression, str), "expression should be a string" # extract argument names used_filters = [] depth = 0 @@ -278,26 +331,42 @@ def _filter_function(expression, context_filters = {}): if filter in context_filters: # brackets around expression prevent unintended "mixing" of neighboring filters expression = expression.replace(filter, f"({context_filters[filter]})") - f = eval(f'lambda: {expression}') + f = eval(f"lambda: {expression}") args = [n for n in f.__code__.co_names if n not in dir(builtins)] used_filters = [arg for arg in args if arg in context_filters] depth += 1 if len(used_filters) == 0: break if depth > 10: - raise ValueError(f'Filter expression evaluation max depth reached. Expression `{original_expression}` was evaluated to `{expression}`') + raise ValueError( + f"Filter expression evaluation max depth reached. Expression `{original_expression}` was evaluated to `{expression}`" + ) # potential issue: gene.coverage gets detected as ["gene", "coverage"], e.g. coverage is added. 
Probably not causing trubble - return eval(f'lambda {",".join([arg+"=None" for arg in args]+["**kwargs"])}: bool({expression})\n', {}, {}), args + return ( + eval( + f'lambda {",".join([arg+"=None" for arg in args]+["**kwargs"])}: bool({expression})\n', + {}, + {}, + ), + args, + ) def _interval_dist(a: tuple[int, int], b: tuple[int, int]): - '''compute the distance between two intervals a and b.''' - return max([a[0], b[0]])-min([a[1], b[1]]) - - -def _filter_event(coverage, event: ASEvent, segment_graph: 'SegmentGraph', min_total=100, min_alt_fraction=.1, min_dist_AB=0): - ''' + """compute the distance between two intervals a and b.""" + return max([a[0], b[0]]) - min([a[1], b[1]]) + + +def _filter_event( + coverage, + event: ASEvent, + segment_graph: "SegmentGraph", + min_total=100, + min_alt_fraction=0.1, + min_dist_AB=0, +): + """ return True if the event satisfies the filter conditions and False otherwise :param coverage: 1D array of counts per transcript @@ -307,9 +376,9 @@ def _filter_event(coverage, event: ASEvent, segment_graph: 'SegmentGraph', min_t :param min_alt_fraction: The minimum fraction of read supporting the alternative :type min_alt_frction: float :param min_dist_AB: Minimum distance (in nucleotides) between node A and B in an event - ''' + """ - tr_IDs = event[0]+event[1] + tr_IDs = event[0] + event[1] tot_cov = coverage[tr_IDs].sum() if tot_cov < min_total: @@ -317,7 +386,7 @@ def _filter_event(coverage, event: ASEvent, segment_graph: 'SegmentGraph', min_t pri_cov = coverage[event[0]].sum() alt_cov = coverage[event[1]].sum() - frac = min(pri_cov, alt_cov)/tot_cov + frac = min(pri_cov, alt_cov) / tot_cov if frac < min_alt_fraction: return False @@ -330,8 +399,12 @@ def _filter_event(coverage, event: ASEvent, segment_graph: 'SegmentGraph', min_t def _get_exonic_region(transcripts): - exon_starts = iter(sorted([e[0] for transcript in transcripts for e in transcript['exons']])) - exon_ends = iter(sorted([e[1] for transcript in transcripts for e 
in transcript['exons']])) + exon_starts = iter( + sorted([e[0] for transcript in transcripts for e in transcript["exons"]]) + ) + exon_ends = iter( + sorted([e[1] for transcript in transcripts for e in transcript["exons"]]) + ) exon_region = [[next(exon_starts), next(exon_ends)]] for next_start in exon_starts: if next_start <= exon_region[-1][1]: @@ -342,12 +415,12 @@ def _get_exonic_region(transcripts): def _get_overlap(exons, transcripts): - '''Compute the exonic overlap of a new transcript with the segment graph. + """Compute the exonic overlap of a new transcript with the segment graph. Avoids the computation of segment graph, which provides the same functionality. :param exons: A list of exon tuples representing the new transcript :type exons: list - :return: boolean array indicating whether the splice site is contained or not''' + :return: boolean array indicating whether the splice site is contained or not""" if not transcripts: return 0 # 1) get exononic regions in transcripts @@ -363,7 +436,7 @@ def _get_overlap(exons, transcripts): while exon_region[i][0] < exon[1]: i_end = min(exon[1], exon_region[i][1]) i_start = max(exon[0], exon_region[i][0]) - ol += (i_end - i_start) + ol += i_end - i_start if exon_region[i][1] > exon[1]: # might overlap with next exon break i += 1 @@ -373,12 +446,12 @@ def _get_overlap(exons, transcripts): def _find_splice_sites(splice_junctions, transcripts): - '''Checks whether the splice sites of a new transcript are present in the set of transcripts. + """Checks whether the splice sites of a new transcript are present in the set of transcripts. Avoids the computation of segment graph, which provides the same functionality. 
:param splice_junctions: A list of 2 tuples with the splice site positions :param transcripts: transcripts to scan - :return: boolean array indicating whether the splice site is contained or not''' + :return: boolean array indicating whether the splice site is contained or not""" sites = np.zeros((len(splice_junctions)) * 2, dtype=bool) # check exon ends @@ -388,9 +461,15 @@ def _find_splice_sites(splice_junctions, transcripts): splice_junction_starts.setdefault(splice_site[0], []).append(i) splice_junction_ends.setdefault(splice_site[1], []).append(i) - transcript_list = [iter(transcript['exons'][:-1]) for transcript in transcripts if len(transcript['exons']) > 1] + transcript_list = [ + iter(transcript["exons"][:-1]) + for transcript in transcripts + if len(transcript["exons"]) > 1 + ] current = [next(transcript) for transcript in transcript_list] - for splice_junction_start, idx in sorted(splice_junction_starts.items()): # splice junction starts, sorted by position + for splice_junction_start, idx in sorted( + splice_junction_starts.items() + ): # splice junction starts, sorted by position for j, transcript_iter in enumerate(transcript_list): try: while splice_junction_start > current[j][1]: @@ -402,9 +481,15 @@ def _find_splice_sites(splice_junctions, transcripts): except StopIteration: continue # check exon starts - transcript_list = [iter(transcript['exons'][1:]) for transcript in transcripts if len(transcript['exons']) > 1] + transcript_list = [ + iter(transcript["exons"][1:]) + for transcript in transcripts + if len(transcript["exons"]) > 1 + ] current = [next(transcript) for transcript in transcript_list] - for splice_junction_end, idx in sorted(splice_junction_ends.items()): # splice junction ends, sorted by position + for splice_junction_end, idx in sorted( + splice_junction_ends.items() + ): # splice junction ends, sorted by position for j, transcript_iter in enumerate(transcript_list): try: while splice_junction_end > current[j][0]: @@ -418,27 +503,47 
@@ def _find_splice_sites(splice_junctions, transcripts): return sites -def precompute_events_dict(transcriptome: 'Transcriptome', event_type: list[ASEType] = ("ES", "5AS", "3AS", "IR", "ME"), - min_cov=100, region=None, query=None, progress_bar=True): - ''' +def precompute_events_dict( + transcriptome: "Transcriptome", + event_type: list[ASEType] = ("ES", "5AS", "3AS", "IR", "ME"), + min_cov=100, + region=None, + query=None, + progress_bar=True, +): + """ Precomputes the events_dict, i.e. a dictionary of splice bubbles. Each key is a gene and each value is the splice bubbles object corresponding to that gene. :param region: The region to be considered. Either a string "chr:start-end", or a tuple (chr,start,end). Start and end is optional. - ''' + """ events_dict = {} - for gene in transcriptome.iter_genes(region=region, query=query, progress_bar=progress_bar): + for gene in transcriptome.iter_genes( + region=region, query=query, progress_bar=progress_bar + ): sg = gene.segment_graph - events = [event for event in sg.find_splice_bubbles(types=event_type) if gene.coverage.sum(axis=0)[event[0]+event[1]].sum() >= min_cov] + events = [ + event + for event in sg.find_splice_bubbles(types=event_type) + if gene.coverage.sum(axis=0)[event[0] + event[1]].sum() >= min_cov + ] if events: events_dict[gene.id] = events return events_dict -def get_quantiles(pos: list[tuple[int, int]], percentile=[.5]): - '''provided a list of (positions,coverage) pairs, return the median position''' +def get_quantiles(pos: list[tuple[int, int]], percentile=None): + """Provided a list of (positions, coverage) pairs, return the median position. + + :param pos: List of tuples containing positions and coverage values. + :param percentile: List of percentiles to calculate (default: [0.5]). + :return: List of positions corresponding to the given percentiles. 
+ """ + if percentile is None: + percentile = [0.5] + # percentile should be sorted, and between 0 and 1 total = sum(cov for _, cov in pos) n = 0 @@ -449,21 +554,21 @@ def get_quantiles(pos: list[tuple[int, int]], percentile=[.5]): result_list.append(p) if len(result_list) == len(percentile): return result_list - raise ValueError(f'cannot find {percentile[len(result_list)]} percentile of {pos}') + raise ValueError(f"cannot find {percentile[len(result_list)]} percentile of {pos}") def smooth(x, window_len=31): - '''smooth the data using a hanning window with requested size.''' + """smooth the data using a hanning window with requested size.""" # padding with mirrored - s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]] + s = np.r_[x[window_len - 1 : 0 : -1], x, x[-2 : -window_len - 1 : -1]] # print(len(s)) w = np.hanning(window_len) - y = np.convolve(w/w.sum(), s, mode='valid') - return y[int(window_len/2-(window_len+1) % 2):-int(window_len/2)] + y = np.convolve(w / w.sum(), s, mode="valid") + return y[int(window_len / 2 - (window_len + 1) % 2) : -int(window_len / 2)] def prepare_contingency_table(eventA: ASEvent, eventB: ASEvent, coverage): - ''' + """ Prepare the read counts and transcript id contingency tables for two events. Returns two 2x2 contingency tables, one with the read counts, one with the transcript events @@ -471,29 +576,33 @@ def prepare_contingency_table(eventA: ASEvent, eventB: ASEvent, coverage): :param eventA: First alternative splicing event obtained from .find_splice_bubbles() :param eventB: Second alternative splicing event obtained from .find_splice_bubbles() :param coverage: Read counts per transcript. 
- ''' + """ con_tab = np.zeros((2, 2), dtype=int) transcript_id_table = np.zeros((2, 2), dtype=object) for m, n in itertools.product(range(2), range(2)): - transcript_ids = sorted(set(eventA[m]) & set(eventB[n]), key=coverage.__getitem__, reverse=True) + transcript_ids = sorted( + set(eventA[m]) & set(eventB[n]), key=coverage.__getitem__, reverse=True + ) transcript_id_table[n, m] = transcript_ids con_tab[n, m] = coverage[transcript_ids].sum() return con_tab, transcript_id_table -def pairwise_event_test(con_tab, test: Literal['fisher', 'chi2'] = "fisher", pseudocount=.01): - ''' +def pairwise_event_test( + con_tab, test: Literal["fisher", "chi2"] = "fisher", pseudocount=0.01 +): + """ Performs an independence test on the contingency table and computes effect sizes. :param con_tab: contingency table with the read counts :param test: Test to be performed. One of ("chi2", "fisher") :type test: str - ''' - if test == 'chi2': + """ + if test == "chi2": test_fun = chi2_contingency - elif test == 'fisher': + elif test == "fisher": test_fun = fisher_exact else: raise (ValueError('test should be "chi2" or "fisher"')) @@ -522,59 +631,73 @@ def _corrected_log2OR(con_tab): con_tab_copy[n, m] = 10**-9 else: con_tab_copy[n, m] = con_tab[n, m] - log2OR = np.log2((con_tab_copy[0, 0]*con_tab_copy[1, 1])) - np.log2((con_tab_copy[0, 1]*con_tab_copy[1, 0])) + log2OR = np.log2((con_tab_copy[0, 0] * con_tab_copy[1, 1])) - np.log2( + (con_tab_copy[0, 1] * con_tab_copy[1, 0]) + ) return log2OR def dcPSI(con_tab): - '''delta conditional PSI of a coordinated event''' + """delta conditional PSI of a coordinated event""" # 1) dcPSI_AB= PSI(B | altA) - PSI(B) - dcPSI_AB = con_tab[1, 1]/con_tab[:, 1].sum()-con_tab[1, :].sum()/con_tab.sum(None) + dcPSI_AB = con_tab[1, 1] / con_tab[:, 1].sum() - con_tab[1, :].sum() / con_tab.sum( + None + ) # 2) dcPSI_BA= PSI(A | altB) - PSI(A) - dcPSI_BA = con_tab[1, 1]/con_tab[1, :].sum()-con_tab[:, 1].sum()/con_tab.sum(None) + dcPSI_BA = con_tab[1, 1] / 
con_tab[1, :].sum() - con_tab[:, 1].sum() / con_tab.sum( + None + ) return dcPSI_AB, dcPSI_BA def genomic_position(tr_pos, exons, reverse_strand): - tr_len = sum((e[1]-e[0]) for e in exons) - assert all(p <= tr_len for p in tr_pos), f'Requested positions {tr_pos} for transcript of length {tr_len}.' - if reverse_strand: - tr_pos = [tr_len-p for p in tr_pos] - tr_pos = sorted(set(tr_pos)) + tr_len = sum((e[1] - e[0]) for e in exons) + if not all(p <= tr_len for p in tr_pos): + raise ValueError( + f"One or more positions in {tr_pos} exceed the transcript length of {tr_len}." + ) + + tr_pos = sorted(set(tr_len - p for p in tr_pos) if reverse_strand else set(tr_pos)) + intron_len = 0 mapped_pos = [] i = 0 offset = exons[0][0] + for e1, e2 in pairwise(exons): - while offset+intron_len+tr_pos[i] < e1[1]: - mapped_pos.append(offset+intron_len+tr_pos[i]) + while offset + intron_len + tr_pos[i] < e1[1]: + mapped_pos.append(offset + intron_len + tr_pos[i]) i += 1 if i == len(tr_pos): break else: - intron_len += e2[0]-e1[1] + intron_len += e2[0] - e1[1] continue break else: - for i in range(i, len(tr_pos)): - mapped_pos.append(offset+intron_len+tr_pos[i]) - if reverse_strand: # get them back to the original - tr_pos = [tr_len-p for p in tr_pos] + for pos in tr_pos[i:]: + mapped_pos.append(offset + intron_len + pos) + + # reverse the positions back to the original if reverse_strand is True + if reverse_strand: + tr_pos = [tr_len - p for p in tr_pos] + return {p: mp for p, mp in zip(tr_pos, mapped_pos)} def cmp_dist(a, b, min_dist=3): - if a >= b+min_dist: + if a >= b + min_dist: return 1 - if b >= a+min_dist: + if b >= a + min_dist: return -1 return 0 # region gene structure variation -def structure_feature_cov(transcripts, samples, feature='TSS'): - ''' + +def structure_feature_cov(transcripts, samples, feature="TSS"): + """ :param transcripts: A list of transcript annotations of a gene obtained from isoseq[gene].transcripts. :param feature: 'EC', 'TSS', 'PAS'. 
:param samples: A list of sample names to specify the samples to be considered. @@ -582,24 +705,24 @@ def structure_feature_cov(transcripts, samples, feature='TSS'): 1) EC - exon_chain, query coverage matrix and exon positions, and return all the exon_chain and coverage. 2) TSS/PAS, query TSS_unified or PAS_unified coverage matrix, and return all the positions and coverage. - ''' + """ - assert feature in ['EC', 'TSS', 'PAS'], 'choose feature from EC, TSS, PAS' + assert feature in ["EC", "TSS", "PAS"], "choose feature from EC, TSS, PAS" cov = {} - if feature == 'EC': - field = 'coverage' + if feature == "EC": + field = "coverage" for transcript in transcripts: if transcript[field] is None: continue # Convert list of exons to tuple of tuples to make it hashable as a key of a dictionary - exon_chain = tuple(map(tuple, transcript['exons'])) + exon_chain = tuple(map(tuple, transcript["exons"])) for s, n in transcript[field].items(): if s in samples: cov[exon_chain] = cov.get(exon_chain, 0) + n else: - field = f'{feature}_unified' + field = f"{feature}_unified" for transcript in transcripts: if transcript[field] is None: continue @@ -621,34 +744,34 @@ def structure_feature_cov(transcripts, samples, feature='TSS'): def count_distinct_pos(pos_list, strict_pos=15): - ''' + """ :param pos_list: A list of TSS/PAS positions, sorted by their abundance descendingly (output from structure_feature_cov). :param strict_pos: Difference allowed between two positions when considering identical TSS/PAS. :return: How many distinct positions are there. - ''' + """ tree = IntervalTree() picked = 0 for pos in pos_list: if len(tree[pos]) == 0: - tree[pos-strict_pos:pos+strict_pos+1] = 1 + tree[pos - strict_pos : pos + strict_pos + 1] = 1 picked += 1 return picked def count_distinct_exon_chain(ec_list, strict_ec=0, strict_pos=15): - ''' + """ :param ec_list: A list of exon chains, sorted by their abundance descendingly (output from structure_feature_cov). 
:param strict_ec: Distance allowed between each position, except for the first/last, in two exon chains so that they can be considered as identical. :param strict_pos: Difference allowed between two positions when considering identical TSS/PAS. :return: How many distinct exon chains are there. - ''' + """ merged_idx = set() - for x in range(len(ec_list)-1): + for x in range(len(ec_list) - 1): if x in merged_idx: continue - for y in range(x+1, len(ec_list)): + for y in range(x + 1, len(ec_list)): if y in merged_idx: continue # if the number of exons is different, skip @@ -658,9 +781,16 @@ def count_distinct_exon_chain(ec_list, strict_ec=0, strict_pos=15): pos_in_x = [pos for exon in ec_list[x] for pos in exon] pos_in_y = [pos for exon in ec_list[y] for pos in exon] - pos_diff = [abs(m - n) for m,n in zip(pos_in_x, pos_in_y)] + pos_diff = [abs(m - n) for m, n in zip(pos_in_x, pos_in_y)] - if all(d <= strict_pos if (i == 0 or i == len(pos_diff)-1) else d <= strict_ec for i,d in enumerate(pos_diff)): + if all( + ( + d <= strict_pos + if (i == 0 or i == len(pos_diff) - 1) + else d <= strict_ec + ) + for i, d in enumerate(pos_diff) + ): # keep the one with higher coverage merged_idx.add(y) @@ -668,24 +798,34 @@ def count_distinct_exon_chain(ec_list, strict_ec=0, strict_pos=15): def str_var_triplet(transcripts, samples, strict_ec=0, strict_pos=15): - ''' + """ Quantify the structure variation of transcripts in a gene across specified samples. - + :param transcripts: A list of transcript annotations of a gene obtained from isoseq[gene].transcripts. :param samples: A list of sample names to specify the samples to be considered.. :param strict_ec: Distance allowed between each position, except for the first/last, in two exon chains so that they can be considered as identical. :param strict_pos: Difference allowed between two positions when considering identical TSS/PAS. 
:return (list): A triplet of numbers in the order of distinct TSS positions, exon chains, and PAS positions. - ''' - - _, ec_list = structure_feature_cov(transcripts=transcripts, samples=samples, feature='EC') - n_ec = count_distinct_exon_chain(ec_list=ec_list, strict_ec=strict_ec, strict_pos=strict_pos) - - _, tss_list = structure_feature_cov(transcripts=transcripts, samples=samples, feature='TSS') + """ + + _, ec_list = structure_feature_cov( + transcripts=transcripts, samples=samples, feature="EC" + ) + n_ec = count_distinct_exon_chain( + ec_list=ec_list, strict_ec=strict_ec, strict_pos=strict_pos + ) + + _, tss_list = structure_feature_cov( + transcripts=transcripts, samples=samples, feature="TSS" + ) n_tss = count_distinct_pos(pos_list=tss_list, strict_pos=strict_pos) - _, pas_list = structure_feature_cov(transcripts=transcripts, samples=samples, feature='PAS') + _, pas_list = structure_feature_cov( + transcripts=transcripts, samples=samples, feature="PAS" + ) n_pas = count_distinct_pos(pos_list=pas_list, strict_pos=strict_pos) return [n_tss, n_ec, n_pas] + + # endregion diff --git a/src/isotools/decorators.py b/src/isotools/decorators.py index 70b770f..1d35a98 100644 --- a/src/isotools/decorators.py +++ b/src/isotools/decorators.py @@ -1,35 +1,41 @@ import functools import logging -logger = logging.getLogger('isotools') +logger = logging.getLogger("isotools") # warn unsave functions def deprecated(func): """Warns about use of deprecated function""" + @functools.wraps(func) def wrapper_depreciated(*args, **kwargs): logger.warning(f"Calling deprecated function {func.__name__}") value = func(*args, **kwargs) return value + return wrapper_depreciated def experimental(func): """Informs about use of untested functionality""" + @functools.wraps(func) def wrapper_experimental(*args, **kwargs): logger.warning(f"Calling {func.__name__}, which is untested/experimental") value = func(*args, **kwargs) return value + return wrapper_experimental + # helpers for debugging 
def debug(func): """Print the function signature and return value""" + @functools.wraps(func) def wrapper_debug(*args, **kwargs): args_repr = [repr(a) for a in args] @@ -39,11 +45,13 @@ def wrapper_debug(*args, **kwargs): value = func(*args, **kwargs) logger.info(f"{func.__name__!r} returned {value!r}") return value + return wrapper_debug def traceback(func): """In case of exception, print the arguments""" + @functools.wraps(func) def wrapper_try(*args, **kwargs): try: @@ -55,4 +63,5 @@ def wrapper_try(*args, **kwargs): signature = ", ".join(args_repr + kwargs_repr) logger.info(f"Exception during call {func.__name__}({signature})") raise e + return wrapper_try diff --git a/src/isotools/domains.py b/src/isotools/domains.py index 0efb01f..b2aff4e 100644 --- a/src/isotools/domains.py +++ b/src/isotools/domains.py @@ -9,29 +9,39 @@ from ._utils import genomic_position, has_overlap -logger = logging.getLogger('isotools') +logger = logging.getLogger("isotools") def parse_hmmer_metadata(fn): with gzip.open(fn) as f: entries = [] entry = {} - while (True): + while True: try: line = next(f).decode().strip() except StopIteration: - metadata = pd.DataFrame(entries).set_index('AC') + metadata = pd.DataFrame(entries).set_index("AC") return metadata - if line == '//': + if line == "//": entries.append(entry.copy()) - elif line.startswith('#=GF'): + elif line.startswith("#=GF"): _, k, v = line.split(maxsplit=2) entry[k] = v -def add_domains_to_table(table, transcriptome, source='annotation', categories=None, id_col='gene_id', modes=['trA-trB', 'trB-trA'], - naming='id', overlap_only=False, insert_after=None, **filter_kwargs): - '''add domain annotation to table. +def add_domains_to_table( + table, + transcriptome, + source="annotation", + categories=None, + id_col="gene_id", + modes=None, + naming="id", + overlap_only=False, + insert_after=None, + **filter_kwargs, +): + """add domain annotation to table. :param table: A table, for which domains are derived. 
It should have at least one column with a gene id and one with a list of transcripts. @@ -47,7 +57,10 @@ def add_domains_to_table(table, transcriptome, source='annotation', categories=N If set "False", all domains of the transcripts are considered. :param insert_after: Define column after which the domains are inserted into the table, either by column name or index. By default, domain columns returned as separate DataFrame. - :param **filter_kwargs: additional keywords are passed to Gene.filter_transcripts, to restrict the transcripts to be considered.''' + :param **filter_kwargs: additional keywords are passed to Gene.filter_transcripts, to restrict the transcripts to be considered. + """ + if modes is None: + modes = ["trA-trB", "trB-trA"] # set operators: # set union: | @@ -55,21 +68,33 @@ def add_domains_to_table(table, transcriptome, source='annotation', categories=N # set intersection: & # check arguments - assert naming in ('id', 'name'), 'naming must be either "id" or "name".' - label_idx = 0 if naming == 'id' else 1 + assert naming in ("id", "name"), 'naming must be either "id" or "name".' + label_idx = 0 if naming == "id" else 1 assert id_col in table, f'Missing id column "{id_col}" in table.' # check the "modes": can they be evaluated? - tr_cols = {tr_col for mode in modes for tr_col in compile(mode, "", "eval").co_names} + tr_cols = { + tr_col + for mode in modes + for tr_col in compile(mode, "", "eval").co_names + } for mode in modes: # only set operations allowed, and should return a set - assert isinstance(eval(mode, {tr_col: set() for tr_col in tr_cols}), set), f'{mode} does not return a set' + assert isinstance( + eval(mode, {tr_col: set() for tr_col in tr_cols}), set + ), f"{mode} does not return a set" # Do they contain only table names? missing = [c for c in tr_cols if c not in table.columns] - assert len(missing) == 0, f'Missing transcript id columns in table: {", ".join(missing)}.' 
+ assert ( + len(missing) == 0 + ), f'Missing transcript id columns in table: {", ".join(missing)}.' if insert_after is not None: if isinstance(insert_after, str): - assert insert_after in table.columns, 'cannot find column "{insert_after}" in table.' + assert ( + insert_after in table.columns + ), 'cannot find column "{insert_after}" in table.' insert_after = table.columns.get_loc(insert_after) - assert isinstance(insert_after, int), 'insert_after must be a column name or column index' + assert isinstance( + insert_after, int + ), "insert_after must be a column name or column index" domain_rows = {} for idx, row in table.iterrows(): @@ -80,9 +105,15 @@ def add_domains_to_table(table, transcriptome, source='annotation', categories=N domain_sets = {} for tr_col in tr_cols: domain_sets[tr_col] = set() - transcript_ids = set(row[tr_col]) & valid_transcripts if filter_kwargs else set(row[tr_col]) + transcript_ids = ( + set(row[tr_col]) & valid_transcripts + if filter_kwargs + else set(row[tr_col]) + ) for transcript_id in transcript_ids: - for dom in gene.transcripts[transcript_id].get('domain', {}).get(source, []): + for dom in ( + gene.transcripts[transcript_id].get("domain", {}).get(source, []) + ): if categories is not None and dom[2] not in categories: continue if overlap_only and not has_overlap(dom[4], (row.start, row.end)): @@ -92,15 +123,29 @@ def add_domains_to_table(table, transcriptome, source='annotation', categories=N # evaluate string in mode domains.append(eval(mode, domain_sets)) domain_rows[idx] = domains - domain_rows = pd.DataFrame.from_dict(domain_rows, orient='index', - columns=[f'{mode} { "overlap " if overlap_only else ""}domains' for mode in modes]) + domain_rows = pd.DataFrame.from_dict( + domain_rows, + orient="index", + columns=[ + f"{mode}{' overlap' if overlap_only else ''} domains" for mode in modes + ], + ) if insert_after is None: return domain_rows - return pd.concat([table.iloc[:, :insert_after+1], domain_rows, table.iloc[:, 
insert_after+1:]], axis=1) + return pd.concat( + [ + table.iloc[:, : insert_after + 1], + domain_rows, + table.iloc[:, insert_after + 1 :], + ], + axis=1, + ) -def import_hmmer_models(path, model_file="Pfam-A.hmm.gz", metadata_file="Pfam-A.hmm.dat.gz"): - '''Import the hmmer model and metadata. +def import_hmmer_models( + path, model_file="Pfam-A.hmm.gz", metadata_file="Pfam-A.hmm.dat.gz" +): + """Import the hmmer model and metadata. This function imports the hmmer Pfam models from "Pfam-A.hmm.gz" and metadata from "Pfam-A.hmm.dat.gz", which are available for download on the interpro website, at "https://www.ebi.ac.uk/interpro/download/Pfam/". @@ -108,24 +153,47 @@ def import_hmmer_models(path, model_file="Pfam-A.hmm.gz", metadata_file="Pfam-A. :param path: The path where model and metadata files are located. :param model_file: The filename of the model file. - :param model_file: The filename of the metadata file.''' + :param model_file: The filename of the metadata file.""" - metadata = parse_hmmer_metadata(f'{path}/{metadata_file}') - with pyhmmer.plan7.HMMFile((f'{path}/{model_file}')) as hmm_file: + metadata = parse_hmmer_metadata(f"{path}/{metadata_file}") + with pyhmmer.plan7.HMMFile((f"{path}/{model_file}")) as hmm_file: hmm_list = list(hmm_file) return metadata, hmm_list -def get_hmmer_sequences(transcriptome, genome_fn, aa_alphabet, query=True, ref_query=False, region=None, min_coverage=None, - max_coverage=None, gois=None, progress_bar=False): - '''Get protein sequences in binary hmmer format.''' +def get_hmmer_sequences( + transcriptome, + genome_fn, + aa_alphabet, + query=True, + ref_query=False, + region=None, + min_coverage=None, + max_coverage=None, + gois=None, + progress_bar=False, +): + """Get protein sequences in binary hmmer format.""" tr_ids = {} if query: - for gene, trids, _ in transcriptome.iter_transcripts(genewise=True, query=query, region=region, min_coverage=min_coverage, - max_coverage=max_coverage, gois=gois, 
progress_bar=progress_bar): + for gene, trids, _ in transcriptome.iter_transcripts( + genewise=True, + query=query, + region=region, + min_coverage=min_coverage, + max_coverage=max_coverage, + gois=gois, + progress_bar=progress_bar, + ): tr_ids.setdefault(gene.id, [[], []])[0] = trids if ref_query: - for gene, trids, _ in transcriptome.iter_ref_transcripts(genewise=True, query=ref_query, region=region, gois=gois, progress_bar=progress_bar): + for gene, trids, _ in transcriptome.iter_ref_transcripts( + genewise=True, + query=ref_query, + region=region, + gois=gois, + progress_bar=progress_bar, + ): tr_ids.setdefault(gene.id, [[], []])[1] = trids sequences = [] @@ -135,22 +203,37 @@ def get_hmmer_sequences(transcriptome, genome_fn, aa_alphabet, query=True, ref_q gene = transcriptome[gene_id] seqs = {} for source in range(2): - for trid, seq in gene.get_sequence(genome_fh, tr_ids[gene_id][source], protein=True, reference=source).items(): + for trid, seq in gene.get_sequence( + genome_fh, tr_ids[gene_id][source], protein=True, reference=source + ).items(): seqs.setdefault(seq, []).append((gene_id, source, trid)) for seq, seqnames in seqs.items(): # Hack: use "name" attribute to store an integer. # Must be string encoded since it is interpreted as 0 terminated string and thus truncated - text_seq = pyhmmer.easel.TextSequence(sequence=seq, name=bytes(str(len(sequences)), 'utf-8')) + text_seq = pyhmmer.easel.TextSequence( + sequence=seq, name=bytes(str(len(sequences)), "utf-8") + ) sequences.append(text_seq.digitize(aa_alphabet)) seq_ids.append(seqnames) return sequences, seq_ids + # function of isoseq.Transcriptome -def add_hmmer_domains(self, domain_models, genome, query=True, ref_query=False, region=None, - min_coverage=None, max_coverage=None, gois=None, progress_bar=False): - '''Align domains to protein sequences using pyhmmer and add them to the transcript isoforms. 
+def add_hmmer_domains( + self, + domain_models, + genome, + query=True, + ref_query=False, + region=None, + min_coverage=None, + max_coverage=None, + gois=None, + progress_bar=False, +): + """Align domains to protein sequences using pyhmmer and add them to the transcript isoforms. :param domain_models: The domain models and metadata, imported by "isotools.domains.import_hmmer_models" function :param genome: Filename of genome fasta file, or Fasta @@ -160,19 +243,34 @@ def add_hmmer_domains(self, domain_models, genome, query=True, ref_query=False, :param min_coverage: The minimum coverage threshold. Transcripts with less reads in total are ignored. :param max_coverage: The maximum coverage threshold. Transcripts with more reads in total are ignored. :param progress_bar: Print progress bars. - ''' + """ metadata, models = domain_models # 1) get the protein sequences pipeline = pyhmmer.plan7.Pipeline(models[0].alphabet) - logging.info('extracting protein sequences...') - sequences, seq_ids = get_hmmer_sequences(self, genome, models[0].alphabet, query, ref_query, region=region, - min_coverage=min_coverage, max_coverage=max_coverage, gois=gois, progress_bar=False) - logging.info(f'found {len(sequences)} different protein sequences from {sum(len(idL) for idL in seq_ids)} coding transcripts.') + logging.info("extracting protein sequences...") + sequences, seq_ids = get_hmmer_sequences( + self, + genome, + models[0].alphabet, + query, + ref_query, + region=region, + min_coverage=min_coverage, + max_coverage=max_coverage, + gois=gois, + progress_bar=False, + ) + logging.info( + f"found {len(sequences)} different protein sequences from {sum(len(idL) for idL in seq_ids)} coding transcripts." 
+ ) # 2) align domain models to sequences - logging.info(f'aligning {len(models)} hmmer domain models to protein sequences...') - hits = {hmm.accession.decode(): pipeline.search_hmm(hmm, sequences) for hmm in tqdm(models, disable=not progress_bar, unit='domains')} + logging.info(f"aligning {len(models)} hmmer domain models to protein sequences...") + hits = { + hmm.accession.decode(): pipeline.search_hmm(hmm, sequences) + for hmm in tqdm(models, disable=not progress_bar, unit="domains") + } # 3) sort domains by gene/source/transcript domains = {} @@ -182,9 +280,17 @@ def add_hmmer_domains(self, domain_models, genome, query=True, ref_query=False, seq_nr = int(h.name.decode()) for domL in h.domains: ali = domL.alignment - transcript_pos = (ali.target_from*3, ali.target_to*3) + transcript_pos = (ali.target_from * 3, ali.target_to * 3) domains.setdefault(seq_nr, []).append( - (pfam_acc, infos['ID'], infos['TP'], transcript_pos, h.score, h.pvalue)) + ( + pfam_acc, + infos["ID"], + infos["TP"], + transcript_pos, + h.score, + h.pvalue, + ) + ) # print(f'{h.name}\t{hmm.name}:\t{ali.target_from}-{ali.target_to}') #which sequence? 
# 4) add domains to transcripts @@ -193,20 +299,51 @@ def add_hmmer_domains(self, domain_models, genome, query=True, ref_query=False, for seq_nr, domL in domains.items(): for gene_id, reference, transcript_id in seq_ids[seq_nr]: gene = self[gene_id] - transcript = gene.ref_transcripts[transcript_id] if reference else gene.transcripts[transcript_id] + transcript = ( + gene.ref_transcripts[transcript_id] + if reference + else gene.transcripts[transcript_id] + ) # get the genomic position of the domain boundaries - orf = sorted(gene.find_transcript_positions(transcript_id, transcript.get('CDS', transcript.get('ORF'))[:2], reference=reference)) - pos_map = genomic_position([p+orf[0] for dom in domL for p in dom[3]], transcript['exons'], gene.strand == '-') - trdom = tuple((*dom[:4], (pos_map[dom[3][0]+orf[0]], pos_map[dom[3][1]+orf[0]]), *dom[4:]) for dom in domL) - transcript.setdefault('domain', {})['hmmer'] = trdom + orf = sorted( + gene.find_transcript_positions( + transcript_id, + transcript.get("CDS", transcript.get("ORF"))[:2], + reference=reference, + ) + ) + pos_map = genomic_position( + [p + orf[0] for dom in domL for p in dom[3]], + transcript["exons"], + gene.strand == "-", + ) + trdom = tuple( + ( + *dom[:4], + (pos_map[dom[3][0] + orf[0]], pos_map[dom[3][1] + orf[0]]), + *dom[4:], + ) + for dom in domL + ) + transcript.setdefault("domain", {})["hmmer"] = trdom tr_count[reference] += 1 dom_count[reference] += len(domL) - logger.info(f'found domains at {dom_count[1]} loci for {tr_count[1]} reference transcripts ' + - f'and at {dom_count[0]} loci for {tr_count[0]} long read transcripts.') - - -def add_annotation_domains(self, annotation, category, id_col='uniProtId', name_col='name', inframe=True, progress_bar=False): - '''Annotate isoforms with protein domains from uniprot ucsc table files. 
+ logger.info( + f"found domains at {dom_count[1]} loci for {tr_count[1]} reference transcripts " + + f"and at {dom_count[0]} loci for {tr_count[0]} long read transcripts." + ) + + +def add_annotation_domains( + self, + annotation, + category, + id_col="uniProtId", + name_col="name", + inframe=True, + progress_bar=False, +): + """Annotate isoforms with protein domains from uniprot ucsc table files. This function adds protein domains and other protein annotation to the transcripts. Annotation tables can be retrieved from https://genome.ucsc.edu/cgi-bin/hgTables. Select @@ -219,119 +356,206 @@ def add_annotation_domains(self, annotation, category, id_col='uniProtId', name_ :param inframe: If set True (default), only annotations starting in frame are added to the transcript. :param append: If set True, the annotation is added to existing annotation. This may lead to duplicate entries. By default, annotation of the same category is removed before annotation is added. - :param progress_bar: If set True, the progress is depicted with a progress bar.''' + :param progress_bar: If set True, the progress is depicted with a progress bar.""" domain_count = 0 # clear domains of that category if isinstance(annotation, str): - anno = pd.read_csv(annotation, sep='\t', low_memory=False) + anno = pd.read_csv(annotation, sep="\t", low_memory=False) elif isinstance(annotation, pd.DataFrame): anno = annotation else: raise ValueError('"annotation" should be file name or pandas.DataFrame object') - anno = anno.rename({'#chrom': 'chrom'}, axis=1) - not_found = [col for col in ['chrom', 'chromStart', 'chromEnd', 'chromStarts', 'blockSizes', name_col] if col not in anno.columns] - assert len(not_found) == 0, f'did not find the following columns in the annotation table: {", ".join(not_found)}' - for _, row in tqdm(anno.iterrows(), total=len(anno), disable=not progress_bar, unit='domains'): - if row['chrom'] not in self.chromosomes: + anno = anno.rename({"#chrom": "chrom"}, axis=1) + 
not_found = [ + col + for col in [ + "chrom", + "chromStart", + "chromEnd", + "chromStarts", + "blockSizes", + name_col, + ] + if col not in anno.columns + ] + assert ( + len(not_found) == 0 + ), f'did not find the following columns in the annotation table: {", ".join(not_found)}' + for _, row in tqdm( + anno.iterrows(), total=len(anno), disable=not progress_bar, unit="domains" + ): + if row["chrom"] not in self.chromosomes: continue - for gene in self.iter_genes(region=(row['chrom'], row.chromStart, row.chromEnd)): - block_starts, block_sizes = list(map(int, row.chromStarts.split(','))), list(map(int, row.blockSizes.split(','))) + for gene in self.iter_genes( + region=(row["chrom"], row.chromStart, row.chromEnd) + ): + block_starts, block_sizes = list( + map(int, row.chromStarts.split(",")) + ), list(map(int, row.blockSizes.split(","))) blocks = [] for start, length in zip(block_starts, block_sizes): - if not blocks or row.chromStart+start > blocks[-1][1]: - blocks.append([row.chromStart+start, row.chromStart+start+length]) + if not blocks or row.chromStart + start > blocks[-1][1]: + blocks.append( + [row.chromStart + start, row.chromStart + start + length] + ) else: - blocks[-1][1] = row.chromStart+start+length + blocks[-1][1] = row.chromStart + start + length for ref in range(2): transcripts = gene.ref_transcripts if ref else gene.transcripts if not transcripts: continue sg = gene.ref_segment_graph if ref else gene.segment_graph - transcript_ids = [transcript_id for transcript_id in sg.search_transcript(blocks, complete=False, include_ends=True) - if 'ORF' in transcripts[transcript_id] or 'CDS' in transcripts[transcript_id]] + transcript_ids = [ + transcript_id + for transcript_id in sg.search_transcript( + blocks, complete=False, include_ends=True + ) + if "ORF" in transcripts[transcript_id] + or "CDS" in transcripts[transcript_id] + ] for transcript_id in transcript_ids: transcript = transcripts[transcript_id] try: - orf_pos = 
sorted(gene.find_transcript_positions(transcript_id, transcript.get('CDS', transcript.get('ORF'))[:2], reference=ref)) - domain_pos = sorted(gene.find_transcript_positions(transcript_id, (row.chromStart, row.chromEnd), reference=ref)) + orf_pos = sorted( + gene.find_transcript_positions( + transcript_id, + transcript.get("CDS", transcript.get("ORF"))[:2], + reference=ref, + ) + ) + domain_pos = sorted( + gene.find_transcript_positions( + transcript_id, + (row.chromStart, row.chromEnd), + reference=ref, + ) + ) except TypeError: # > not supported for None, None continue - if not (orf_pos[0] <= domain_pos[0] and domain_pos[1] <= orf_pos[1]): # check within ORF + if not ( + orf_pos[0] <= domain_pos[0] and domain_pos[1] <= orf_pos[1] + ): # check within ORF continue - if inframe and (domain_pos[0]-orf_pos[0]) % 3 != 0: # check inframe + if ( + inframe and (domain_pos[0] - orf_pos[0]) % 3 != 0 + ): # check inframe continue domain_count += 1 - dom_pos = (domain_pos[0]-orf_pos[0], domain_pos[1]-orf_pos[0]) - dom_vals = (row[id_col], row[name_col], category, dom_pos, (row.chromStart, row.chromEnd)) + dom_pos = (domain_pos[0] - orf_pos[0], domain_pos[1] - orf_pos[0]) + dom_vals = ( + row[id_col], + row[name_col], + category, + dom_pos, + (row.chromStart, row.chromEnd), + ) # check if present already - if dom_vals not in transcript.setdefault('domain', {}).setdefault('annotation', []): - transcript['domain']['annotation'].append(dom_vals) - - logger.info(f'found domains at {domain_count} transcript loci') - - -def get_interpro_domains(seqs, email, baseUrl='http://www.ebi.ac.uk/Tools/services/rest/iprscan5', progress_bar=True, max_jobs=25, poll_time=5): - '''Request domains from ebi interpro REST API. 
- - Returns a list of the json responses as received, one for each requested sequenced.''' + if dom_vals not in transcript.setdefault("domain", {}).setdefault( + "annotation", [] + ): + transcript["domain"]["annotation"].append(dom_vals) + + logger.info(f"found domains at {domain_count} transcript loci") + + +def get_interpro_domains( + seqs, + email, + baseUrl="http://www.ebi.ac.uk/Tools/services/rest/iprscan5", + progress_bar=True, + max_jobs=25, + poll_time=5, +): + """Request domains from ebi interpro REST API. + + Returns a list of the json responses as received, one for each requested sequenced. + """ # examples at https://raw.githubusercontent.com/ebi-wp/webservice-clients/master/python/iprscan5.py - requestUrl = baseUrl + u'/run/' + requestUrl = baseUrl + "/run/" current_jobs = {} if isinstance(seqs, str): seqs = [seqs] - domains = [None]*len(seqs) + domains = [None] * len(seqs) i = 0 - with tqdm(unit='proteins', disable=not progress_bar, total=len(seqs)) as pbar: + with tqdm(unit="proteins", disable=not progress_bar, total=len(seqs)) as pbar: try: while seqs or current_jobs: # still something to do while seqs and len(current_jobs) < max_jobs: # start more jobs - params = {u'email': email, u'sequence': seqs.pop()} + params = {"email": email, "sequence": seqs.pop()} resp = requests.post(requestUrl, data=params) # todo: error handling - what if request fails? current_jobs[resp.content.decode()] = i - pbar.set_description(f'waiting for {len(current_jobs)} jobs') + pbar.set_description(f"waiting for {len(current_jobs)} jobs") i += 1 time.sleep(poll_time) done = set() for job_id, idx in current_jobs.items(): # check the current jobs - url = baseUrl + u'/status/' + job_id + url = baseUrl + "/status/" + job_id resp = requests.get(url) if not resp.ok: # todo: error handling: e.g. timeout error? 
continue status = resp.content.decode() - if status in ('PENDING', 'RUNNING'): + if status in ("PENDING", "RUNNING"): continue - elif status == 'FINISHED': - url = baseUrl + u'/result/' + job_id + '/json' + elif status == "FINISHED": + url = baseUrl + "/result/" + job_id + "/json" resp = requests.get(url) if resp.ok: - domains[idx] = resp.json()['results'] # else? + domains[idx] = resp.json()["results"] # else? pbar.update(1) else: - domains[idx] = [{'status': 'FAILED', 'reason': 'resp_not_ok', 'jobid': job_id}] + domains[idx] = [ + { + "status": "FAILED", + "reason": "resp_not_ok", + "jobid": job_id, + } + ] done.add(job_id) - elif status == 'FAILED': # try again? - logger.warning(f'Failed to get response for sequence {idx}') + elif status == "FAILED": # try again? + logger.warning(f"Failed to get response for sequence {idx}") done.add(job_id) - domains[idx] = [{'status': 'FAILED', 'reason': 'job failed', 'jobid': job_id}] + domains[idx] = [ + { + "status": "FAILED", + "reason": "job failed", + "jobid": job_id, + } + ] else: - logger.warning(f'unhandled status for sequence {idx}: jobid={job_id}') + logger.warning( + f"unhandled status for sequence {idx}: jobid={job_id}" + ) for job_id in done: # remove the finished jobs current_jobs.pop(job_id) - pbar.set_description(f'waiting for {len(current_jobs)} jobs') + pbar.set_description(f"waiting for {len(current_jobs)} jobs") except KeyboardInterrupt: - logger.warning(f'Interrupting retrieval of {len(current_jobs)} jobs: [{",".join(current_jobs)}]') # give the user the chance to check the jobs + logger.warning( + f'Interrupting retrieval of {len(current_jobs)} jobs: [{",".join(current_jobs)}]' + ) # give the user the chance to check the jobs return domains + # method of isotools.Gene -def add_interpro_domains(self, genome, email, baseUrl='http://www.ebi.ac.uk/Tools/services/rest/iprscan5', max_jobs=25, poll_time=5, - query=True, ref_query=False, min_coverage=None, max_coverage=None, progress_bar=True): - '''Add domains 
to gene by webrequests to ebi interpro REST API. +def add_interpro_domains( + self, + genome, + email, + baseUrl="http://www.ebi.ac.uk/Tools/services/rest/iprscan5", + max_jobs=25, + poll_time=5, + query=True, + ref_query=False, + min_coverage=None, + max_coverage=None, + progress_bar=True, +): + """Add domains to gene by webrequests to ebi interpro REST API. This function adds protein domains from interpro to the transcripts. Note that these rquest may take around 60 seconds per sequence. @@ -345,38 +569,73 @@ def add_interpro_domains(self, genome, email, baseUrl='http://www.ebi.ac.uk/Tool :param ref_query: Query string to select the reference transcripts, or True/False to include/exclude all transcripts. :param min_coverage: The minimum coverage threshold. Transcripts with less reads in total are ignored. :param max_coverage: The maximum coverage threshold. Transcripts with more reads in total are ignored. - :param progress_bar: If set True, the progress is depicted with a progress bar.''' + :param progress_bar: If set True, the progress is depicted with a progress bar.""" seqs = {} # seq -> transcript_ids dict, to avoid requesting the same sequence if query: - transcript_ids = self.filter_transcripts(query=query, min_coverage=min_coverage, max_coverage=max_coverage) - for transcript_id, seq in self.get_sequence(genome, transcript_ids, protein=True).items(): - seqs.setdefault(seq, {}).setdefault('isotools', []).append(transcript_id) + transcript_ids = self.filter_transcripts( + query=query, min_coverage=min_coverage, max_coverage=max_coverage + ) + for transcript_id, seq in self.get_sequence( + genome, transcript_ids, protein=True + ).items(): + seqs.setdefault(seq, {}).setdefault("isotools", []).append(transcript_id) if ref_query: ref_transcript_ids = self.filter_ref_transcripts(query=ref_query) - for transcript_id, seq in self.get_sequence(genome, ref_transcript_ids, protein=True, reference=True).items(): - seqs.setdefault(seq, {}).setdefault('reference', 
[]).append(transcript_id) - - dom_results = get_interpro_domains(list(seqs.keys()), email, baseUrl, progress_bar, max_jobs, poll_time) + for transcript_id, seq in self.get_sequence( + genome, ref_transcript_ids, protein=True, reference=True + ).items(): + seqs.setdefault(seq, {}).setdefault("reference", []).append(transcript_id) + + dom_results = get_interpro_domains( + list(seqs.keys()), email, baseUrl, progress_bar, max_jobs, poll_time + ) for i, (dom,) in enumerate(dom_results): - if 'matches' not in dom: - logger.warning(f'no response for sequence of {list(seqs.values())[i]}') + if "matches" not in dom: + logger.warning(f"no response for sequence of {list(seqs.values())[i]}") continue domL = [] - for match in dom['matches']: - for loc in match['locations']: - entry = match['signature'].get('entry') - domL.append((str(match['signature']['accession']), # short name - str(match['signature']['name']), - entry.get('type', "unknown") if entry else "unknown", # type - (loc['start']*3, loc['end']*3), # position - loc.get('hmmBounds'))) # completeness + for match in dom["matches"]: + for loc in match["locations"]: + entry = match["signature"].get("entry") + domL.append( + ( + str(match["signature"]["accession"]), # short name + str(match["signature"]["name"]), + entry.get("type", "unknown") if entry else "unknown", # type + (loc["start"] * 3, loc["end"] * 3), # position + loc.get("hmmBounds"), + ) + ) # completeness # todo: potentially add more relevant information here for reference in range(2): - for transcript_id in seqs[dom['sequence']].get('reference' if reference else 'isotools', []): - transcript = self.ref_transcripts[transcript_id] if reference else self.transcripts[transcript_id] - orf = sorted(self.find_transcript_positions(transcript_id, transcript.get('CDS', transcript.get('ORF'))[:2], reference=reference)) - pos_map = genomic_position([p+orf[0] for dom in domL for p in dom[3]], transcript['exons'], self.strand == '-') - trdom = tuple((*dom[:4], 
(pos_map[dom[3][0]+orf[0]], pos_map[dom[3][1]+orf[0]]), *dom[4:]) for dom in domL) - - transcript.setdefault('domain', {})['interpro'] = trdom + for transcript_id in seqs[dom["sequence"]].get( + "reference" if reference else "isotools", [] + ): + transcript = ( + self.ref_transcripts[transcript_id] + if reference + else self.transcripts[transcript_id] + ) + orf = sorted( + self.find_transcript_positions( + transcript_id, + transcript.get("CDS", transcript.get("ORF"))[:2], + reference=reference, + ) + ) + pos_map = genomic_position( + [p + orf[0] for dom in domL for p in dom[3]], + transcript["exons"], + self.strand == "-", + ) + trdom = tuple( + ( + *dom[:4], + (pos_map[dom[3][0] + orf[0]], pos_map[dom[3][1] + orf[0]]), + *dom[4:], + ) + for dom in domL + ) + + transcript.setdefault("domain", {})["interpro"] = trdom diff --git a/src/isotools/gene.py b/src/isotools/gene.py index 8c29dd3..a516b6d 100644 --- a/src/isotools/gene.py +++ b/src/isotools/gene.py @@ -14,8 +14,19 @@ from .splice_graph import SegmentGraph from .short_read import Coverage from ._transcriptome_filter import SPLICE_CATEGORY -from ._utils import pairwise, _filter_event, find_orfs, DEFAULT_KOZAK_PWM, kozak_score, smooth, get_quantiles, \ - _filter_function, pairwise_event_test, prepare_contingency_table, cmp_dist +from ._utils import ( + pairwise, + _filter_event, + find_orfs, + DEFAULT_KOZAK_PWM, + kozak_score, + smooth, + get_quantiles, + _filter_function, + pairwise_event_test, + prepare_contingency_table, + cmp_dist, +) from typing import Any, Literal, Optional, TypedDict, TYPE_CHECKING if TYPE_CHECKING: @@ -24,7 +35,8 @@ from ._utils import ASEvent import logging -logger = logging.getLogger('isotools') + +logger = logging.getLogger("isotools") class SQANTI_classification(TypedDict): @@ -37,21 +49,24 @@ class SQANTI_classification(TypedDict): polyA_motif_found: bool ratio_TSS: float + class Transcript(TypedDict, total=False): chr: str - strand: Literal['+', '-'] + strand: Literal["+", "-"] 
exons: list[tuple[int, int]] coverage: dict[str, int] - 'The coverage of the transcript in each sample.' + "The coverage of the transcript in each sample." TSS: dict[str, dict[int, int]] - 'The TSS of each sample with their coverage.' + "The TSS of each sample with their coverage." PAS: dict[str, dict[int, int]] - 'The PAS of each sample with their coverage.' + "The PAS of each sample with their coverage." clipping: dict[str, dict[str, int]] - annotation: tuple[int, dict[str, Any]] # TODO: Switch the dict to a TypedDict, Replace novelty with Enum + annotation: tuple[ + int, dict[str, Any] + ] # TODO: Switch the dict to a TypedDict, Replace novelty with Enum "The annotation of the transcript. The first element is the novelty class (0=FSM,1=ISM,2=NIC,3=NNC,4=Novel gene), the second a dictionary with the subcategories." reads: dict[str, list[str]] - 'sample names as keys, list of reads as values.' + "sample names as keys, list of reads as values." novel_splice_sites: list[int] TSS_unified: dict[str, dict[int, int]] PAS_unified: dict[str, dict[int, int]] @@ -63,6 +78,7 @@ class Transcript(TypedDict, total=False): fuzzy_junction: Any sqanti_classification: SQANTI_classification + class RefTranscript(TypedDict, total=False): transcript_id: str transcript_type: str @@ -71,193 +87,300 @@ class RefTranscript(TypedDict, total=False): exons: list[tuple[int, int]] CDS: tuple[int, int] + class ReferenceData(TypedDict, total=False): segment_graph: SegmentGraph transcripts: list[RefTranscript] + class GeneData(TypedDict, total=False): ID: str name: str chr: str - strand: Literal['+', '-'] + strand: Literal["+", "-"] short_reads: list[Coverage] coverage: np.ndarray reference: ReferenceData segment_graph: SegmentGraph transcripts: list[Transcript] + class Gene(Interval): - 'This class stores all gene information and transcripts. It is derived from intervaltree.Interval.' - required_infos = ['ID', 'chr', 'strand'] + "This class stores all gene information and transcripts. 
It is derived from intervaltree.Interval." + + required_infos = ["ID", "chr", "strand"] data: GeneData - _transcriptome: 'Transcriptome' + _transcriptome: "Transcriptome" # initialization def __new__(cls, begin, end, data: GeneData, transcriptome): - return super().__new__(cls, begin, end, data) # required as Interval (and Gene) is immutable + return super().__new__( + cls, begin, end, data + ) # required as Interval (and Gene) is immutable def __init__(self, begin, end, data: GeneData, transcriptome: Transcriptome): self._transcriptome = transcriptome def __str__(self): - return 'Gene {} {}({}), {} reference transcripts, {} expressed transcripts'.format( - self.name, self.region, self.strand, self.n_ref_transcripts, self.n_transcripts) + return ( + "Gene {} {}({}), {} reference transcripts, {} expressed transcripts".format( + self.name, + self.region, + self.strand, + self.n_ref_transcripts, + self.n_transcripts, + ) + ) def __repr__(self): return object.__repr__(self) - from ._gene_plots import sashimi_plot, gene_track, sashimi_plot_short_reads, sashimi_figure, plot_domains + from ._gene_plots import ( + sashimi_plot, + gene_track, + sashimi_plot_short_reads, + sashimi_figure, + plot_domains, + ) from .domains import add_interpro_domains def short_reads(self, idx): - '''Returns the short read coverage profile for a short read sample. + """Returns the short read coverage profile for a short read sample. :param idx: The index of the short read sample. 
- :returns: The short read coverage profile.''' + :returns: The short read coverage profile.""" try: # raises key_error if no short reads added - return self.data['short_reads'][idx] + return self.data["short_reads"][idx] except (KeyError, IndexError): - srdf = self._transcriptome.infos['short_reads'] - self.data.setdefault('short_reads', []) - for i in range(len(self.data['short_reads']), len(srdf)): - self.data['short_reads'].append(Coverage.from_bam(srdf.file[i], self)) - return self.data['short_reads'][idx] + srdf = self._transcriptome.infos["short_reads"] + self.data.setdefault("short_reads", []) + for i in range(len(self.data["short_reads"]), len(srdf)): + self.data["short_reads"].append(Coverage.from_bam(srdf.file[i], self)) + return self.data["short_reads"][idx] def correct_fuzzy_junctions(self, transcript: Transcript, size, modify=True): - '''Corrects for splicing shifts. + """Corrects for splicing shifts. - This function looks for "shifted junctions", e.g. same difference compared to reference annotation at both donor and acceptor) - presumably caused by ambiguous alignments. In these cases the positions are adapted to the reference position (if modify is set). + This function looks for "shifted junctions", e.g. same difference compared to reference annotation at both donor and acceptor) + presumably caused by ambiguous alignments. In these cases the positions are adapted to the reference position (if modify is set). - :param transcript_id: The index of the transcript to be checked. - :param size: The maximum shift to be corrected. - :param modify: If set, the exon positions are corrected according to the reference. - :returns: A dictionary with the exon id as keys and the shifted bases as values.''' + :param transcript_id: The index of the transcript to be checked. + :param size: The maximum shift to be corrected. + :param modify: If set, the exon positions are corrected according to the reference. 
+ :returns: A dictionary with the exon id as keys and the shifted bases as values. + """ - exons = transcript['exons'] + exons = transcript["exons"] shifts = self.ref_segment_graph.fuzzy_junction(exons, size) if shifts and modify: for i, sh in shifts.items(): - if exons[i][0] <= exons[i][1] + sh and exons[i + 1][0] + sh <= exons[i + 1][1]: + if ( + exons[i][0] <= exons[i][1] + sh + and exons[i + 1][0] + sh <= exons[i + 1][1] + ): exons[i][1] += sh exons[i + 1][0] += sh - transcript['exons'] = [e for e in exons if e[0] < e[1]] # remove zero length exons + transcript["exons"] = [ + e for e in exons if e[0] < e[1] + ] # remove zero length exons return shifts - def _to_gtf(self, transcript_ids, ref_transcript_ids=None, source='isoseq', ref_source='annotation'): - '''Creates the gtf lines of the gene as strings.''' - donotshow = {'transcripts', 'short_exons', 'segment_graph'} - info = {'gene_id': self.id, 'gene_name': self.name} + def _to_gtf( + self, + transcript_ids, + ref_transcript_ids=None, + source="isoseq", + ref_source="annotation", + ): + """Creates the gtf lines of the gene as strings.""" + donotshow = {"transcripts", "short_exons", "segment_graph"} + info = {"gene_id": self.id, "gene_name": self.name} lines = [None] starts = [] ends = [] ref_fsm = [] for i in transcript_ids: transcript = self.transcripts[i] - info['transcript_id'] = f'{info["gene_id"]}_{i}' - starts.append(transcript['exons'][0][0] + 1) - ends.append(transcript['exons'][-1][1]) + info["transcript_id"] = f'{info["gene_id"]}_{i}' + starts.append(transcript["exons"][0][0] + 1) + ends.append(transcript["exons"][-1][1]) transcript_info = info.copy() - if 'downstream_A_content' in transcript: - transcript_info['downstream_A_content'] = f'{transcript["downstream_A_content"]:0.3f}' - if transcript['annotation'][0] == 0: # FSM + if "downstream_A_content" in transcript: + transcript_info["downstream_A_content"] = ( + f'{transcript["downstream_A_content"]:0.3f}' + ) + if transcript["annotation"][0] == 
0: # FSM refinfo = {} - for refid in transcript['annotation'][1]['FSM']: + for refid in transcript["annotation"][1]["FSM"]: ref_fsm.append(refid) for k in self.ref_transcripts[refid]: - if k == 'exons': + if k == "exons": continue - elif k == 'CDS': - if self.strand == '+': - cds_start, cds_end = self.ref_transcripts[refid]['CDS'] + elif k == "CDS": + if self.strand == "+": + cds_start, cds_end = self.ref_transcripts[refid]["CDS"] else: - cds_end, cds_start = self.ref_transcripts[refid]['CDS'] - refinfo.setdefault('CDS_start', []).append(str(cds_start)) - refinfo.setdefault('CDS_end', []).append(str(cds_end)) + cds_end, cds_start = self.ref_transcripts[refid]["CDS"] + refinfo.setdefault("CDS_start", []).append(str(cds_start)) + refinfo.setdefault("CDS_end", []).append(str(cds_end)) else: - refinfo.setdefault(k, []).append(str(self.ref_transcripts[refid][k])) + refinfo.setdefault(k, []).append( + str(self.ref_transcripts[refid][k]) + ) for k, vlist in refinfo.items(): - transcript_info[f'ref_{k}'] = ','.join(vlist) + transcript_info[f"ref_{k}"] = ",".join(vlist) else: - transcript_info['novelty'] = ','.join(k for k in transcript['annotation'][1]) - lines.append((self.chrom, source, 'transcript', transcript['exons'][0][0] + 1, transcript['exons'][-1][1], '.', - self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in transcript_info.items()))) - noncanonical = transcript.get('noncanonical_splicing', []) - for enr, pos in enumerate(transcript['exons']): + transcript_info["novelty"] = ",".join( + k for k in transcript["annotation"][1] + ) + lines.append( + ( + self.chrom, + source, + "transcript", + transcript["exons"][0][0] + 1, + transcript["exons"][-1][1], + ".", + self.strand, + ".", + "; ".join(f'{k} "{v}"' for k, v in transcript_info.items()), + ) + ) + noncanonical = transcript.get("noncanonical_splicing", []) + for enr, pos in enumerate(transcript["exons"]): exon_info = info.copy() - exon_info['exon_id'] = f'{info["gene_id"]}_{i}_{enr}' + exon_info["exon_id"] = 
f'{info["gene_id"]}_{i}_{enr}' if enr in noncanonical: - exon_info['noncanonical_donor'] = noncanonical[enr][:2] - if enr+1 in noncanonical: - exon_info['noncanonical_acceptor'] = noncanonical[enr+1][2:] - lines.append((self.chrom, source, 'exon', pos[0] + 1, pos[1], '.', self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in exon_info.items()))) + exon_info["noncanonical_donor"] = noncanonical[enr][:2] + if enr + 1 in noncanonical: + exon_info["noncanonical_acceptor"] = noncanonical[enr + 1][2:] + lines.append( + ( + self.chrom, + source, + "exon", + pos[0] + 1, + pos[1], + ".", + self.strand, + ".", + "; ".join(f'{k} "{v}"' for k, v in exon_info.items()), + ) + ) if ref_transcript_ids: # add reference transcripts not covered by FSM for i, transcript in enumerate(self.ref_transcripts): if i in ref_fsm: continue - starts.append(transcript['exons'][0][0] + 1) - ends.append(transcript['exons'][-1][1]) - info['transcript_id'] = f'{info["gene_id"]}_ref{i}' + starts.append(transcript["exons"][0][0] + 1) + ends.append(transcript["exons"][-1][1]) + info["transcript_id"] = f'{info["gene_id"]}_ref{i}' refinfo = info.copy() for k in transcript: - if k == 'exons': + if k == "exons": continue - elif k == 'CDS': - if self.strand == '+': - cds_start, cds_end = transcript['CDS'] + elif k == "CDS": + if self.strand == "+": + cds_start, cds_end = transcript["CDS"] else: - cds_end, cds_start = transcript['CDS'] - refinfo['CDS_start'] = str(cds_start) - refinfo['CDS_end'] = str(cds_end) + cds_end, cds_start = transcript["CDS"] + refinfo["CDS_start"] = str(cds_start) + refinfo["CDS_end"] = str(cds_end) else: refinfo.setdefault(k, []).append(str(transcript[k])) - lines.append((self.chrom, ref_source, 'transcript', transcript['exons'][0][0] + 1, transcript['exons'][-1][1], '.', - self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in refinfo.items()))) - for enr, pos in enumerate(transcript['exons']): + lines.append( + ( + self.chrom, + ref_source, + "transcript", + 
transcript["exons"][0][0] + 1, + transcript["exons"][-1][1], + ".", + self.strand, + ".", + "; ".join(f'{k} "{v}"' for k, v in refinfo.items()), + ) + ) + for enr, pos in enumerate(transcript["exons"]): exon_info = info.copy() exon_id = f'{info["gene_id"]}_ref{i}_{enr}' - lines.append((self.chrom, ref_source, 'exon', pos[0] + 1, pos[1], '.', self.strand, '.', f'exon_id "{exon_id}"')) + lines.append( + ( + self.chrom, + ref_source, + "exon", + pos[0] + 1, + pos[1], + ".", + self.strand, + ".", + f'exon_id "{exon_id}"', + ) + ) if len(lines) > 1: # add gene line - if 'reference' in self.data: + if "reference" in self.data: # add reference gene specific fields - info.update({k: v for k, v in self.data['reference'].items() if k not in donotshow}) - lines[0] = (self.chrom, source, 'gene', min(starts), max(ends), '.', self.strand, '.', '; '.join(f'{k} "{v}"' for k, v in info.items())) + info.update( + { + k: v + for k, v in self.data["reference"].items() + if k not in donotshow + } + ) + lines[0] = ( + self.chrom, + source, + "gene", + min(starts), + max(ends), + ".", + self.strand, + ".", + "; ".join(f'{k} "{v}"' for k, v in info.items()), + ) return lines return [] def add_noncanonical_splicing(self, genome_fh): - '''Add information on noncanonical splicing. + """Add information on noncanonical splicing. For all transcripts of the gene, scan for noncanonical (i.e. not GT-AG) splice sites. If noncanonical splice sites are present, the corresponding intron index (in genomic orientation) and the sequence i.e. the di-nucleotides of donor and acceptor as XX-YY string are stored in the "noncannoncical_splicing" field of the transcript dicts. True noncanonical splicing is rare, thus it might indicate technical artifacts (template switching, misalignment, ...) 
- :param genome_fh: A file handle of the genome fastA file.''' + :param genome_fh: A file handle of the genome fastA file.""" ss_seq = {} for transcript in self.transcripts: - pos = [(transcript['exons'][i][1], transcript['exons'][i + 1][0] - 2) for i in range(len(transcript['exons']) - 1)] - new_ss_seq = {site: genome_fh.fetch(self.chrom, site, site + 2).upper() for intron in pos for site in intron if site not in ss_seq} + pos = [ + (transcript["exons"][i][1], transcript["exons"][i + 1][0] - 2) + for i in range(len(transcript["exons"]) - 1) + ] + new_ss_seq = { + site: genome_fh.fetch(self.chrom, site, site + 2).upper() + for intron in pos + for site in intron + if site not in ss_seq + } if new_ss_seq: ss_seq.update(new_ss_seq) - if self.strand == '+': + if self.strand == "+": sj_seq = [ss_seq[d] + ss_seq[a] for d, a in pos] else: sj_seq = [reverse_complement(ss_seq[d] + ss_seq[a]) for d, a in pos] - nc = [(i, seq) for i, seq in enumerate(sj_seq) if seq != 'GTAG'] + nc = [(i, seq) for i, seq in enumerate(sj_seq) if seq != "GTAG"] if nc: - transcript['noncanonical_splicing'] = nc + transcript["noncanonical_splicing"] = nc def add_direct_repeat_len(self, genome_fh, delta=15, max_mm=2, wobble=2): - '''Computes direct repeat length. + """Computes direct repeat length. This function counts the number of consecutive equal bases at donor and acceptor sites of the splice junctions. This information is stored in the "direct_repeat_len" filed of the transcript dictionaries. @@ -266,82 +389,117 @@ def add_direct_repeat_len(self, genome_fh, delta=15, max_mm=2, wobble=2): :param genome_fh: The file handle to the genome fastA. :param delta: The maximum length of direct repeats that can be found. :param max_mm: The maximum length of direct repeats that can be found. 
- :param wobble: The maximum length of direct repeats that can be found.''' + :param wobble: The maximum length of direct repeats that can be found.""" intron_seq = {} score = {} for transcript in self.transcripts: - for intron in ((transcript['exons'][i][1], transcript['exons'][i + 1][0]) for i in range(len(transcript['exons']) - 1)): + for intron in ( + (transcript["exons"][i][1], transcript["exons"][i + 1][0]) + for i in range(len(transcript["exons"]) - 1) + ): for pos in intron: try: - intron_seq.setdefault(pos, genome_fh.fetch(self.chrom, pos - delta, pos + delta)) - except (ValueError, IndexError): # N padding at start/end of the chromosomes + intron_seq.setdefault( + pos, genome_fh.fetch(self.chrom, pos - delta, pos + delta) + ) + except ( + ValueError, + IndexError, + ): # N padding at start/end of the chromosomes chr_len = genome_fh.get_reference_length(self.chrom) - seq = genome_fh.fetch(self.chrom, max(0, pos - delta), min(chr_len, pos + delta)) + seq = genome_fh.fetch( + self.chrom, max(0, pos - delta), min(chr_len, pos + delta) + ) if pos - delta < 0: - seq = ''.join(['N'] * (pos - delta)) + seq + seq = "".join(["N"] * (pos - delta)) + seq if pos + delta > chr_len: - seq += ''.join(['N'] * (pos + delta - chr_len)) + seq += "".join(["N"] * (pos + delta - chr_len)) intron_seq.setdefault(pos, seq) if intron not in score: - score[intron] = repeat_len(intron_seq[intron[0]], intron_seq[intron[1]], wobble=wobble, max_mm=max_mm) + score[intron] = repeat_len( + intron_seq[intron[0]], + intron_seq[intron[1]], + wobble=wobble, + max_mm=max_mm, + ) for transcript in self.transcripts: - transcript['direct_repeat_len'] = [min(score[(exon_1[1], exon_2[0])], delta) for exon_1, exon_2 in pairwise(transcript['exons'])] + transcript["direct_repeat_len"] = [ + min(score[(exon_1[1], exon_2[0])], delta) + for exon_1, exon_2 in pairwise(transcript["exons"]) + ] def add_threeprime_a_content(self, genome_fh, length=30): - '''Adds the information of the genomic A content 
downstream the transcript. + """Adds the information of the genomic A content downstream the transcript. High values of genomic A content indicate internal priming and hence genomic origin of the LRTS read. This function populates the 'downstream_A_content' field of the transcript dictionaries. :param geneome_fh: A file handle for the indexed genome fastA file. :param length: The length of the downstream region to be considered. - ''' + """ a_content = {} - for transcript in (t for tL in (self.transcripts, self.ref_transcripts) for t in tL): - if self.strand == '+': - pos = transcript['exons'][-1][1] + for transcript in ( + t for tL in (self.transcripts, self.ref_transcripts) for t in tL + ): + if self.strand == "+": + pos = transcript["exons"][-1][1] else: - pos = transcript['exons'][0][0] - length + pos = transcript["exons"][0][0] - length if pos not in a_content: seq = genome_fh.fetch(self.chrom, max(0, pos), pos + length) - if self.strand == '+': - a_content[pos] = seq.upper().count('A') / length + if self.strand == "+": + a_content[pos] = seq.upper().count("A") / length else: - a_content[pos] = seq.upper().count('T') / length - transcript['downstream_A_content'] = a_content[pos] + a_content[pos] = seq.upper().count("T") / length + transcript["downstream_A_content"] = a_content[pos] def add_sqanti_classification(self, transcript_id: int, classification_row: Series): - assert transcript_id < self.n_transcripts, f'Transcript id {transcript_id} not found in gene {self.id}' - infos = classification_row[[ - 'dist_to_CAGE_peak', - 'within_CAGE_peak', - 'dist_to_polyA_site', - 'within_polyA_site', - 'polyA_motif', - 'polyA_dist', - 'polyA_motif_found', - 'ratio_TSS' - ]].to_dict() - self.transcripts[transcript_id]['sqanti_classification'] = infos - - def get_sequence(self, genome_fh, transcript_ids=None, reference=False, protein=False): - '''Returns the nucleotide sequence of the specified transcripts. 
+ assert ( + transcript_id < self.n_transcripts + ), f"Transcript id {transcript_id} not found in gene {self.id}" + infos = classification_row[ + [ + "dist_to_CAGE_peak", + "within_CAGE_peak", + "dist_to_polyA_site", + "within_polyA_site", + "polyA_motif", + "polyA_dist", + "polyA_motif_found", + "ratio_TSS", + ] + ].to_dict() + self.transcripts[transcript_id]["sqanti_classification"] = infos + + def get_sequence( + self, genome_fh, transcript_ids=None, reference=False, protein=False + ): + """Returns the nucleotide sequence of the specified transcripts. :param genome_fh: The path to the genome fastA file, or FastaFile handle. :param transcript_ids: List of transcript ids for which the sequence are requested. :param reference: Specify whether the sequence is fetched for reference transcripts (True) or long read transcripts (False, default). - :param protein: Return protein sequences instead of transcript sequences. + :param protein: Return translated protein coding sequences instead of full transcript sequences. :returns: A dictionary of transcript ids and their sequences. 
- ''' - - trL = [(i, transcript) for i, transcript in enumerate(self.ref_transcripts if reference else self.transcripts) if transcript_ids is None or i in transcript_ids] + """ + + trL = [ + (i, transcript) + for i, transcript in enumerate( + self.ref_transcripts if reference else self.transcripts + ) + if transcript_ids is None or i in transcript_ids + ] if not trL: return {} - pos = (min(transcript['exons'][0][0] for _, transcript in trL), max(transcript['exons'][-1][1] for _, transcript in trL)) + pos = ( + min(transcript["exons"][0][0] for _, transcript in trL), + max(transcript["exons"][-1][1] for _, transcript in trL), + ) try: # assume its a FastaFile file handle seq = genome_fh.fetch(self.chrom, *pos) except AttributeError: @@ -350,71 +508,109 @@ def get_sequence(self, genome_fh, transcript_ids=None, reference=False, protein= seq = genome_fh.fetch(self.chrom, *pos) transcript_seqs = {} for i, transcript in trL: - transcript_seq = '' - for exon in transcript['exons']: - transcript_seq += seq[exon[0]-pos[0]:exon[1]-pos[0]] + transcript_seq = "" + for exon in transcript["exons"]: + transcript_seq += seq[exon[0] - pos[0] : exon[1] - pos[0]] transcript_seqs[i] = transcript_seq - if self.strand == '-': - transcript_seqs = {i: reverse_complement(ts) for i, ts in transcript_seqs.items()} + if self.strand == "-": + transcript_seqs = { + i: reverse_complement(ts) for i, ts in transcript_seqs.items() + } if not protein: return transcript_seqs + prot_seqs = {} for i, transcript in trL: orf = transcript.get("CDS", transcript.get("ORF")) if not orf: continue - pos = sorted(self.find_transcript_positions(i, orf[:2], reference=reference)) + pos = sorted( + self.find_transcript_positions(i, orf[:2], reference=reference) + ) try: - prot_seqs[i] = translate(transcript_seqs[i][pos[0]:pos[1]], cds=True) + prot_seqs[i] = translate(transcript_seqs[i][pos[0] : pos[1]], cds=True) except TranslationError: - logger.warning(f'CDS sequence of {self.id} {"reference" if reference else 
""} transcript {i} cannot be translated.') + logger.warning( + f'CDS sequence of {self.id} {"reference" if reference else ""} transcript {i} cannot be translated.' + ) return prot_seqs def _get_ref_cds_pos(self, transcript_ids=None): - '''find the position of annotated CDS initiation ''' + """find the position of annotated CDS initiation""" if transcript_ids is None: transcript_ids = range(len(self.transcripts)) - reverse_strand = self.strand == '-' + reverse_strand = self.strand == "-" utr, anno_cds = {}, {} match_cds = {} for i, transcript in enumerate(self.ref_transcripts): - if 'CDS' in transcript: - anno_cds[i] = transcript['CDS'] + if "CDS" in transcript: + anno_cds[i] = transcript["CDS"] if not reverse_strand: - utr[i] = [exon for exon in transcript['exons'] if exon[0] < transcript['CDS'][0]] + utr[i] = [ + exon + for exon in transcript["exons"] + if exon[0] < transcript["CDS"][0] + ] else: - utr[i] = [exon for exon in transcript['exons'] if exon[1] > transcript['CDS'][1]] + utr[i] = [ + exon + for exon in transcript["exons"] + if exon[1] > transcript["CDS"][1] + ] for transcript_id in transcript_ids: transcript = self.transcripts[transcript_id] match_cds[transcript_id] = {} for i, reg in utr.items(): - if not any(start <= anno_cds[i][reverse_strand] <= end for start, end in transcript['exons']): # no overlap of CDS init with exons + if not any( + start <= anno_cds[i][reverse_strand] <= end + for start, end in transcript["exons"] + ): # no overlap of CDS init with exons continue if not reverse_strand: - to_check = zip(pairwise(reg), pairwise(transcript['exons'])) + to_check = zip(pairwise(reg), pairwise(transcript["exons"])) else: - to_check = zip(pairwise((end, start) for start, end in reversed(reg)), pairwise((e, s) for s, e in reversed(transcript['exons']))) - for ((e1reg, e2reg), (e1, e2)) in to_check: - if (e1reg[1] != e1[1] or e2reg[0] != e2[0]): + to_check = zip( + pairwise((end, start) for start, end in reversed(reg)), + pairwise((e, s) for s, e in 
reversed(transcript["exons"])), + ) + for (e1reg, e2reg), (e1, e2) in to_check: + if e1reg[1] != e1[1] or e2reg[0] != e2[0]: break else: - pos = self.find_transcript_positions(transcript_id, anno_cds[i], reference=False)[reverse_strand] + pos = self.find_transcript_positions( + transcript_id, anno_cds[i], reference=False + )[reverse_strand] match_cds[transcript_id].setdefault(pos, []).append(i) return match_cds - def add_orfs(self, genome_fh, tr_filter={}, reference=False, minlen=300, min_kozak=None, max_5utr_len=0, prefer_annotated_init=True, start_codons=["ATG"], - stop_codons=['TAA', 'TAG', 'TGA'], kozak_matrix=DEFAULT_KOZAK_PWM, get_fickett=True, coding_hexamers=None, noncoding_hexamers=None): - '''Predict the CDS for each transcript. + def add_orfs( + self, + genome_fh, + tr_filter=None, + reference=False, + minlen=300, + min_kozak=None, + max_5utr_len=0, + prefer_annotated_init=True, + start_codons=None, + stop_codons=None, + kozak_matrix=DEFAULT_KOZAK_PWM, + get_fickett=True, + coding_hexamers=None, + noncoding_hexamers=None, + ): + """Predict the CDS for each transcript. For each transcript, one ORF is selected as the coding sequence. Depending on the parameters, this is either the first ORF - (sequence starting with start_condon, and ending with in frame stop codon), or the longest ORF, + (sequence starting with start_condon, and ending with in-frame stop codon), or the longest ORF, starting with a codon that is annotated as CDS initiation site in a reference transcript. - The genomic and transcript positions of these codons, and the length of the ORF, as well as the number of upstream start codons + The genomic and transcript positions of these codons, and the length of the ORF, as well as the number of upstream start codons is added to the transcript properties transcript["ORF"]. Additionally, the Fickett score, and the hexamer score are computed. For the latter, hexamer frequencies in coding and noncoding transcripts are needed. 
See CPAT python module for prebuilt tables and instructions. + :param tr_filter: dict with filtering parameters passed to iter_transcripts or iter_ref_transcripts :param min_len: Minimum length of the ORF, Does not apply to annotated initiation sites. :param min_kozak: Minimal score for translation initiation site. Does not apply to annotated initiation sites. @@ -425,32 +621,60 @@ def add_orfs(self, genome_fh, tr_filter={}, reference=False, minlen=300, min_koz :param kozak_matrix: A PWM (log odds ratios) to compute the Kozak sequence similarity :param get_fickett: If true, the fickett score for the CDS is computed. :param coding_hexamers: The hexamer frequencies for coding sequences. - :param noncoding_hexamers: The hexamer frequencies for non-coding sequences (background).''' + :param noncoding_hexamers: The hexamer frequencies for non-coding sequences (background). + """ + if tr_filter is None: + tr_filter = {} + if start_codons is None: + start_codons = ["ATG"] + if stop_codons is None: + stop_codons = ["TAA", "TAG", "TGA"] + if tr_filter: if reference: - tr_dict = {i: self.ref_transcripts[i] for i in self.filter_ref_transcripts(**tr_filter)} + tr_dict = { + i: self.ref_transcripts[i] + for i in self.filter_ref_transcripts(**tr_filter) + } else: - tr_dict = {i: self.transcripts[i] for i in self.filter_transcripts(**tr_filter)} + tr_dict = { + i: self.transcripts[i] for i in self.filter_transcripts(**tr_filter) + } else: - tr_dict = {i: transcript for i, transcript in enumerate(self.ref_transcripts if reference else self.transcripts)} - assert min_kozak is None or kozak_matrix is not None, 'Kozak matrix missing for min_kozak' + tr_dict = { + i: transcript + for i, transcript in enumerate( + self.ref_transcripts if reference else self.transcripts + ) + } + assert ( + min_kozak is None or kozak_matrix is not None + ), "Kozak matrix missing for min_kozak" if not tr_dict: return if prefer_annotated_init: if reference: ref_cds = {} for i, transcript in 
tr_dict.items(): - if 'CDS' not in transcript: + if "CDS" not in transcript: ref_cds[i] = {} else: - pos = self.find_transcript_positions(i, transcript['CDS'], reference=True)[self.strand == '-'] + pos = self.find_transcript_positions( + i, transcript["CDS"], reference=True + )[self.strand == "-"] ref_cds[i] = {pos: [i]} else: ref_cds = self._get_ref_cds_pos(transcript_ids=tr_dict.keys()) - for transcript_id, tr_seq in self.get_sequence(genome_fh, transcript_ids=tr_dict.keys(), reference=reference).items(): + for transcript_id, tr_seq in self.get_sequence( + genome_fh, transcript_ids=tr_dict.keys(), reference=reference + ).items(): orfs = find_orfs( - tr_seq, start_codons, stop_codons, ref_cds[transcript_id] if prefer_annotated_init else []) + tr_seq, + start_codons, + stop_codons, + ref_cds[transcript_id] if prefer_annotated_init else [], + ) if not orfs: # No ORF continue # select best ORF @@ -458,195 +682,281 @@ def add_orfs(self, genome_fh, tr_filter={}, reference=False, minlen=300, min_koz anno_orfs = [orf for orf in orfs if bool(orf[6]) and orf[1] is not None] kozak = None if anno_orfs: - start, stop, frame, seq_start, seq_end, uORFs, ref_transcript_ids = max(anno_orfs, key=lambda x: x[1]-x[0]) + start, stop, frame, seq_start, seq_end, uORFs, ref_transcript_ids = max( + anno_orfs, key=lambda x: x[1] - x[0] + ) else: - valid_orfs = [orf for orf in orfs if orf[1] is not None and orf[1]-orf[0] > minlen and orf[0] <= max_5utr_len] + valid_orfs = [ + orf + for orf in orfs + if orf[1] is not None + and orf[1] - orf[0] > minlen + and orf[0] <= max_5utr_len + ] if not valid_orfs: continue if min_kozak is not None: for orf in valid_orfs: kozak = kozak_score(tr_seq, orf[0], kozak_matrix) if kozak > min_kozak: - start, stop, frame, seq_start, seq_end, uORFs, ref_transcript_ids = orf + ( + start, + stop, + frame, + seq_start, + seq_end, + uORFs, + ref_transcript_ids, + ) = orf break else: continue else: - start, stop, frame, seq_start, seq_end, uORFs, ref_transcript_ids 
= valid_orfs[0] - - start, stop, frame, seq_start, seq_end, uORFs, ref_transcript_ids = valid_orfs[0] + ( + start, + stop, + frame, + seq_start, + seq_end, + uORFs, + ref_transcript_ids, + ) = valid_orfs[0] + + start, stop, frame, seq_start, seq_end, uORFs, ref_transcript_ids = ( + valid_orfs[0] + ) # if stop is None or stop - start < minlen: # continue transcript = tr_dict[transcript_id] - transcript_start = transcript['exons'][0][0] - cum_exon_len = np.cumsum([end-start for start, end in transcript['exons']]) # cumulative exon length - cum_intron_len = np.cumsum([0]+[end-start for (_, start), (end, _) in pairwise(transcript['exons'])]) # cumulative intron length - if self.strand == '-': - fwd_start, fwd_stop = cum_exon_len[-1]-stop, cum_exon_len[-1]-start + transcript_start = transcript["exons"][0][0] + cum_exon_len = np.cumsum( + [end - start for start, end in transcript["exons"]] + ) # cumulative exon length + cum_intron_len = np.cumsum( + [0] + + [ + end - start + for (_, start), (end, _) in pairwise(transcript["exons"]) + ] + ) # cumulative intron length + if self.strand == "-": + fwd_start, fwd_stop = cum_exon_len[-1] - stop, cum_exon_len[-1] - start else: - fwd_start, fwd_stop = start, stop # start/stop position wrt genomic fwd strand - start_exon = next(i for i in range(len(cum_exon_len)) if cum_exon_len[i] >= fwd_start) - stop_exon = next(i for i in range(start_exon, len(cum_exon_len)) if cum_exon_len[i] >= fwd_stop) - genome_pos = (transcript_start+fwd_start+cum_intron_len[start_exon], - transcript_start+fwd_stop+cum_intron_len[stop_exon]) + fwd_start, fwd_stop = ( + start, + stop, + ) # start/stop position wrt genomic fwd strand + start_exon = next( + i for i in range(len(cum_exon_len)) if cum_exon_len[i] >= fwd_start + ) + stop_exon = next( + i + for i in range(start_exon, len(cum_exon_len)) + if cum_exon_len[i] >= fwd_stop + ) + genome_pos = ( + transcript_start + fwd_start + cum_intron_len[start_exon], + transcript_start + fwd_stop + 
cum_intron_len[stop_exon], + ) dist_pas = 0 # distance of termination codon to last upstream splice site - if self.strand == '+' and stop_exon < len(cum_exon_len)-1: - dist_pas = cum_exon_len[-2]-fwd_stop - if self.strand == '-' and start_exon > 0: - dist_pas = fwd_start-cum_exon_len[0] - orf_dict = {"5'UTR": start, 'CDS': stop-start, "3'UTR": cum_exon_len[-1]-stop, - 'start_codon': seq_start, 'stop_codon': seq_end, 'NMD': dist_pas > 55, 'uORFs': uORFs, 'ref_ids': ref_transcript_ids} + if self.strand == "+" and stop_exon < len(cum_exon_len) - 1: + dist_pas = cum_exon_len[-2] - fwd_stop + if self.strand == "-" and start_exon > 0: + dist_pas = fwd_start - cum_exon_len[0] + orf_dict = { + "5'UTR": start, + "CDS": stop - start, + "3'UTR": cum_exon_len[-1] - stop, + "start_codon": seq_start, + "stop_codon": seq_end, + "NMD": dist_pas > 55, + "uORFs": uORFs, + "ref_ids": ref_transcript_ids, + } if kozak_matrix is not None: if kozak is None: - orf_dict['kozak'] = kozak_score(tr_seq, start, kozak_matrix) + orf_dict["kozak"] = kozak_score(tr_seq, start, kozak_matrix) else: - orf_dict['kozak'] = kozak + orf_dict["kozak"] = kozak if coding_hexamers is not None and noncoding_hexamers is not None: - orf_dict['hexamer'] = FrameKmer.kmer_ratio(tr_seq[start:stop], 6, 3, coding_hexamers, noncoding_hexamers) + orf_dict["hexamer"] = FrameKmer.kmer_ratio( + tr_seq[start:stop], 6, 3, coding_hexamers, noncoding_hexamers + ) if get_fickett: - orf_dict['fickett'] = fickett.fickett_value(tr_seq[start:stop]) - transcript['ORF'] = (*genome_pos, orf_dict) + orf_dict["fickett"] = fickett.fickett_value(tr_seq[start:stop]) + transcript["ORF"] = (*genome_pos, orf_dict) def add_fragments(self): - '''Checks for transcripts that are fully contained in other transcripts. + """Checks for transcripts that are fully contained in other transcripts. Transcripts that are fully contained in other transcripts are potential truncations. 
This function populates the 'fragment' field of the transcript dictionaries with the indices of the containing transcripts, - and the exon ids that match the first and last exons.''' + and the exon ids that match the first and last exons.""" for transcript_id, containers in self.segment_graph.find_fragments().items(): - self.transcripts[transcript_id]['fragments'] = containers # list of (containing transcript id, first 5' exons, first 3'exons) + self.transcripts[transcript_id][ + "fragments" + ] = containers # list of (containing transcript id, first 5' exons, first 3'exons) def coding_len(self, transcript_id): - '''Returns length of 5\'UTR, coding sequence and 3\'UTR. + """Returns length of 5\'UTR, coding sequence and 3\'UTR. - :param transcript_id: The transcript index for which the coding length is requested. ''' + :param transcript_id: The transcript index for which the coding length is requested. + """ try: - exons = self.transcripts[transcript_id]['exons'] - cds = self.transcripts[transcript_id]['CDS'] + exons = self.transcripts[transcript_id]["exons"] + cds = self.transcripts[transcript_id]["CDS"] except KeyError: return None else: coding_len = _coding_len(exons, cds) - if self.strand == '-': + if self.strand == "-": coding_len.reverse() return coding_len def get_infos(self, transcript_id, keys, sample_i, group_i, **kwargs): - '''Returns the transcript information specified in "keys" as a list.''' - return [value for k in keys for value in self._get_info(transcript_id, k, sample_i, group_i)] + """Returns the transcript information specified in "keys" as a list.""" + return [ + value + for k in keys + for value in self._get_info(transcript_id, k, sample_i, group_i) + ] def _get_info(self, transcript_id, key, sample_i, group_i, **kwargs): # returns tuples (as some keys return multiple values) - if key == 'length': - return sum((e - b for b, e in self.transcripts[transcript_id]['exons'])), - elif key == 'n_exons': - return 
len(self.transcripts[transcript_id]['exons']), - elif key == 'exon_starts': - return ','.join(str(e[0]) for e in self.transcripts[transcript_id]['exons']), - elif key == 'exon_ends': - return ','.join(str(e[1]) for e in self.transcripts[transcript_id]['exons']), - elif key == 'annotation': + if key == "length": + return (sum((e - b for b, e in self.transcripts[transcript_id]["exons"])),) + elif key == "n_exons": + return (len(self.transcripts[transcript_id]["exons"]),) + elif key == "exon_starts": + return ( + ",".join(str(e[0]) for e in self.transcripts[transcript_id]["exons"]), + ) + elif key == "exon_ends": + return ( + ",".join(str(e[1]) for e in self.transcripts[transcript_id]["exons"]), + ) + elif key == "annotation": # sel=['sj_i','base_i', 'as'] - if 'annotation' not in self.transcripts[transcript_id]: - return ('NA',) * 2 - nov_class, subcat = self.transcripts[transcript_id]['annotation'] + if "annotation" not in self.transcripts[transcript_id]: + return ("NA",) * 2 + nov_class, subcat = self.transcripts[transcript_id]["annotation"] # subcat_string = ';'.join(k if v is None else '{}:{}'.format(k, v) for k, v in subcat.items()) - return SPLICE_CATEGORY[nov_class], ','.join(subcat) # only the names of the subcategories - elif key == 'coverage': + return SPLICE_CATEGORY[nov_class], ",".join( + subcat + ) # only the names of the subcategories + elif key == "coverage": return self.coverage[sample_i, transcript_id] - elif key == 'tpm': - return self.tpm(kwargs.get('pseudocount', 1))[sample_i, transcript_id] - elif key == 'group_coverage_sum': + elif key == "tpm": + return self.tpm(kwargs.get("pseudocount", 1))[sample_i, transcript_id] + elif key == "group_coverage_sum": return tuple(self.coverage[si, transcript_id].sum() for si in group_i) - elif key == 'group_tpm_mean': - return tuple(self.tpm(kwargs.get('pseudocount', 1))[si, transcript_id].mean() for si in group_i) + elif key == "group_tpm_mean": + return tuple( + self.tpm(kwargs.get("pseudocount", 1))[si, 
transcript_id].mean() + for si in group_i + ) elif key in self.transcripts[transcript_id]: val = self.transcripts[transcript_id][key] if isinstance(val, Iterable): # iterables get converted to string - return str(val), + return (str(val),) else: - return val, # atomic (e.g. numeric) - return 'NA', + return (val,) # atomic (e.g. numeric) + return ("NA",) def _set_coverage(self, force=False): samples = self._transcriptome.samples cov = np.zeros((len(samples), self.n_transcripts), dtype=int) if not force: # keep the segment graph if no new transcripts - known = self.data.get('coverage', None) + known = self.data.get("coverage", None) if known is not None and known.shape[1] == self.n_transcripts: if known.shape == cov.shape: return - cov[:known.shape[0], :] = known + cov[: known.shape[0], :] = known for i in range(known.shape[0], len(samples)): for j, transcript in enumerate(self.transcripts): - cov[i, j] = transcript['coverage'].get(samples[i], 0) - self.data['coverage'] = cov + cov[i, j] = transcript["coverage"].get(samples[i], 0) + self.data["coverage"] = cov return for i, sample in enumerate(samples): for j, transcript in enumerate(self.transcripts): - cov[i, j] = transcript['coverage'].get(sample, 0) - self.data['coverage'] = cov - self.data['segment_graph'] = None + cov[i, j] = transcript["coverage"].get(sample, 0) + self.data["coverage"] = cov + self.data["segment_graph"] = None def tpm(self, pseudocount=1): - '''Returns the transcripts per million (TPM). - - TPM is returned as a numpy array, with samples in columns and transcript isoforms in the rows.''' - return (self.coverage+pseudocount)/self._transcriptome.sample_table['nonchimeric_reads'].values.reshape(-1, 1)*1e6 + """Returns the transcripts per million (TPM). + + TPM is returned as a numpy array, with samples in columns and transcript isoforms in the rows. 
+ """ + return ( + (self.coverage + pseudocount) + / self._transcriptome.sample_table["nonchimeric_reads"].values.reshape( + -1, 1 + ) + * 1e6 + ) def find_transcript_positions(self, transcript_id, pos, reference=False): - '''Converts genomic positions to positions within the transcript. + """Converts genomic positions to positions within the transcript. :param transcript_id: The transcript id - :param pos: List of sorted genomic positions, for which the transcript positions are computed.''' + :param pos: List of sorted genomic positions, for which the transcript positions are computed. + """ tr_pos = [] - exons = self.ref_transcripts[transcript_id]['exons'] if reference else self.transcripts[transcript_id]['exons'] + exons = ( + self.ref_transcripts[transcript_id]["exons"] + if reference + else self.transcripts[transcript_id]["exons"] + ) e_idx = 0 offset = 0 for p in sorted(pos): try: while p > exons[e_idx][1]: - offset += (exons[e_idx][1]-exons[e_idx][0]) + offset += exons[e_idx][1] - exons[e_idx][0] e_idx += 1 except IndexError: - for _ in range(len(pos)-len(tr_pos)): + for _ in range(len(pos) - len(tr_pos)): tr_pos.append(None) break - tr_pos.append(offset+p-exons[e_idx][0] if p >= exons[e_idx][0] else None) - if self.strand == '-': - trlen = sum(end-start for start, end in exons) - tr_pos = [None if p is None else trlen-p for p in tr_pos] + tr_pos.append( + offset + p - exons[e_idx][0] if p >= exons[e_idx][0] else None + ) + if self.strand == "-": + trlen = sum(end - start for start, end in exons) + tr_pos = [None if p is None else trlen - p for p in tr_pos] return tr_pos @property def coverage(self): - '''Returns the transcript coverage. + """Returns the transcript coverage. - Coverage is returned as a numpy array, with samples in columns and transcript isoforms in the rows.''' - cov = self.data.get('coverage', None) + Coverage is returned as a numpy array, with samples in columns and transcript isoforms in the rows. 
+ """ + cov = self.data.get("coverage", None) if cov is not None: return cov self._set_coverage() - return self.data['coverage'] + return self.data["coverage"] @property def gene_coverage(self): - '''Returns the gene coverage. + """Returns the gene coverage. - Total Coverage of the gene for each sample.''' + Total Coverage of the gene for each sample.""" return self.coverage.sum(1) @property def chrom(self): - '''Returns the genes chromosome.''' - return self.data['chr'] + """Returns the genes chromosome.""" + return self.data["chr"] @property def start(self): # alias for begin @@ -654,120 +964,143 @@ def start(self): # alias for begin @property def region(self): - '''Returns the region of the gene as a string in the format "chr:start-end".''' + """Returns the region of the gene as a string in the format "chr:start-end".""" try: - return '{}:{}-{}'.format(self.chrom, self.start, self.end) + return "{}:{}-{}".format(self.chrom, self.start, self.end) except KeyError: raise @property def id(self): - '''Returns the gene id''' + """Returns the gene id""" try: - return self.data['ID'] + return self.data["ID"] except KeyError: logger.error(self.data) raise @property def name(self): - '''Returns the gene name''' + """Returns the gene name""" try: - return self.data['name'] + return self.data["name"] except KeyError: return self.id # e.g. novel genes do not have a name (but id) @property def is_annotated(self): - '''Returns "True" iff reference annotation is present for the gene.''' - return 'reference' in self.data + """Returns "True" iff reference annotation is present for the gene.""" + return "reference" in self.data @property def is_expressed(self): - '''Returns "True" iff gene is covered by at least one long read in at least one sample.''' + """Returns "True" iff gene is covered by at least one long read in at least one sample.""" return bool(self.transcripts) @property def strand(self): '''Returns the strand of the gene, e.g. 
"+" or "-"''' - return self.data['strand'] + return self.data["strand"] @property def transcripts(self) -> list[Transcript]: - '''Returns the list of transcripts of the gene, as found by LRTS.''' + """Returns the list of transcripts of the gene, as found by LRTS.""" try: - return self.data['transcripts'] + return self.data["transcripts"] except KeyError: return [] @property def ref_transcripts(self) -> list[Transcript]: - '''Returns the list of reference transcripts of the gene.''' + """Returns the list of reference transcripts of the gene.""" try: - return self.data['reference']['transcripts'] + return self.data["reference"]["transcripts"] except KeyError: return [] @property def n_transcripts(self): - '''Returns number of transcripts of the gene, as found by LRTS.''' + """Returns number of transcripts of the gene, as found by LRTS.""" return len(self.transcripts) @property def n_ref_transcripts(self): - '''Returns number of reference transcripts of the gene.''' + """Returns number of reference transcripts of the gene.""" return len(self.ref_transcripts) @property def ref_segment_graph(self): # raises key error if not self.is_annotated - '''Returns the segment graph of the reference transcripts for the gene''' + """Returns the segment graph of the reference transcripts for the gene""" assert self.is_annotated, "reference segment graph requested on novel gene" - if 'segment_graph' not in self.data['reference'] or self.data['reference']['segment_graph'] is None: - transcript_exons = [transcript['exons'] for transcript in self.ref_transcripts] - self.data['reference']['segment_graph'] = SegmentGraph(transcript_exons, self.strand) - return self.data['reference']['segment_graph'] + if ( + "segment_graph" not in self.data["reference"] + or self.data["reference"]["segment_graph"] is None + ): + transcript_exons = [ + transcript["exons"] for transcript in self.ref_transcripts + ] + self.data["reference"]["segment_graph"] = SegmentGraph( + transcript_exons, self.strand + ) 
+ return self.data["reference"]["segment_graph"] @property def segment_graph(self): - '''Returns the segment graph of the LRTS transcripts for the gene''' - if 'segment_graph' not in self.data or self.data['segment_graph'] is None: - transcript_exons = [transcript['exons'] for transcript in self.transcripts] + """Returns the segment graph of the LRTS transcripts for the gene""" + if "segment_graph" not in self.data or self.data["segment_graph"] is None: + transcript_exons = [transcript["exons"] for transcript in self.transcripts] try: - self.data['segment_graph'] = SegmentGraph(transcript_exons, self.strand) + self.data["segment_graph"] = SegmentGraph(transcript_exons, self.strand) except Exception: - logger.error('Error initializing Segment Graph on %s with exons %s', self.strand, transcript_exons) + logger.error( + "Error initializing Segment Graph on %s with exons %s", + self.strand, + transcript_exons, + ) raise - return self.data['segment_graph'] + return self.data["segment_graph"] def segment_graph_filtered(self, query=None, min_coverage=None, max_coverage=None): - '''Returns a filtered segment graph of the LRTS transcripts for the gene''' + """Returns a filtered segment graph of the LRTS transcripts for the gene""" transcript_ids = self.filter_transcripts(query, min_coverage, max_coverage) - transcript_exons = [transcript['exons'] for i, transcript in enumerate(self.transcripts) if i in transcript_ids] + transcript_exons = [ + transcript["exons"] + for i, transcript in enumerate(self.transcripts) + if i in transcript_ids + ] return SegmentGraph(transcript_exons, self.strand) def __copy__(self): return Gene(self.start, self.end, self.data, self._transcriptome) def __deepcopy__(self, memo): # does not copy _transcriptome! 
- return Gene(self.start, self.end, copy.deepcopy(self.data, memo), self._transcriptome) + return Gene( + self.start, self.end, copy.deepcopy(self.data, memo), self._transcriptome + ) def __reduce__(self): return Gene, (self.start, self.end, self.data, self._transcriptome) def copy(self): - 'Returns a shallow copy of self.' + "Returns a shallow copy of self." return self.__copy__() def filter_transcripts(self, query=None, min_coverage=None, max_coverage=None): if query: - transcript_filter = self._transcriptome.filter['transcript'] + transcript_filter = self._transcriptome.filter["transcript"] # used_tags={tag for tag in re.findall(r'\b\w+\b', query) if tag not in BOOL_OP} query_fun, used_tags = _filter_function(query) - msg = 'did not find the following filter rules: {}\nvalid rules are: {}' + msg = "did not find the following filter rules: {}\nvalid rules are: {}" assert all(f in transcript_filter for f in used_tags), msg.format( - ', '.join(f for f in used_tags if f not in transcript_filter), ', '.join(transcript_filter)) - transcript_filter_fun = {tag: _filter_function(tag, transcript_filter)[0] for tag in used_tags if tag in transcript_filter} + ", ".join(f for f in used_tags if f not in transcript_filter), + ", ".join(transcript_filter), + ) + transcript_filter_fun = { + tag: _filter_function(tag, transcript_filter)[0] + for tag in used_tags + if tag in transcript_filter + } transcript_ids = [] for i, transcript in enumerate(self.transcripts): if min_coverage and self.coverage[:, i].sum() < min_coverage: @@ -775,39 +1108,54 @@ def filter_transcripts(self, query=None, min_coverage=None, max_coverage=None): if max_coverage and self.coverage[:, i].sum() > max_coverage: continue if query is None or query_fun( - **{tag: f(gene=self, trid=i, **transcript) for tag, f in transcript_filter_fun.items()}): + **{ + tag: f(gene=self, trid=i, **transcript) + for tag, f in transcript_filter_fun.items() + } + ): transcript_ids.append(i) return transcript_ids def 
filter_ref_transcripts(self, query=None): if query: - transcript_filter = self._transcriptome.filter['reference'] + transcript_filter = self._transcriptome.filter["reference"] # used_tags={tag for tag in re.findall(r'\b\w+\b', query) if tag not in BOOL_OP} query_fun, used_tags = _filter_function(query) - msg = 'did not find the following filter rules: {}\nvalid rules are: {}' + msg = "did not find the following filter rules: {}\nvalid rules are: {}" assert all(f in transcript_filter for f in used_tags), msg.format( - ', '.join(f for f in used_tags if f not in transcript_filter), ', '.join(transcript_filter)) - transcript_filter_func = {tag: _filter_function(transcript_filter[tag])[0] for tag in used_tags if tag in transcript_filter} + ", ".join(f for f in used_tags if f not in transcript_filter), + ", ".join(transcript_filter), + ) + transcript_filter_func = { + tag: _filter_function(transcript_filter[tag])[0] + for tag in used_tags + if tag in transcript_filter + } else: return list(range(len(self.ref_transcripts))) transcript_ids = [] for i, transcript in enumerate(self.ref_transcripts): filter_transcript = transcript.copy() - if query_fun(**{tag: f(gene=self, trid=i, **filter_transcript) for tag, f in transcript_filter_func.items()}): + if query_fun( + **{ + tag: f(gene=self, trid=i, **filter_transcript) + for tag, f in transcript_filter_func.items() + } + ): transcript_ids.append(i) return transcript_ids def _find_splice_sites(exons, transcripts): - '''Checks whether the splice sites of a new transcript are present in the set of transcripts. + """Checks whether the splice sites of a new transcript are present in the set of transcripts. avoids the computation of segment graph, which provides the same functionality. 
:param exons: A list of exon tuples representing the transcript :type exons: list - :return: boolean array indicating whether the splice site is contained or not''' + :return: boolean array indicating whether the splice site is contained or not""" - intron_iter = [pairwise(transcript['exons']) for transcript in transcripts] + intron_iter = [pairwise(transcript["exons"]) for transcript in transcripts] current = [next(transcript) for transcript in intron_iter] - contained = np.zeros(len(exons)-1) + contained = np.zeros(len(exons) - 1) for j, (exon1, exon2) in enumerate(pairwise(exons)): for i, transcript in enumerate(intron_iter): while current[i][0][1] < exon1[1]: @@ -819,10 +1167,19 @@ def _find_splice_sites(exons, transcripts): contained[j] = True return current - def coordination_test(self, samples=None, test: Literal['fisher', 'chi2'] = "fisher", min_dist_AB=1, min_dist_events=1, min_total=100, min_alt_fraction=.1, - events: Optional[list[ASEvent]] = None, event_type=("ES", "5AS", "3AS", "IR", "ME"), - transcript_filter: Optional[str] = None) -> list[tuple]: - '''Performs pairwise independence test for all pairs of Alternative Splicing Events (ASEs) in a gene. + def coordination_test( + self, + samples=None, + test: Literal["fisher", "chi2"] = "fisher", + min_dist_AB=1, + min_dist_events=1, + min_total=100, + min_alt_fraction=0.1, + events: Optional[list[ASEvent]] = None, + event_type=("ES", "5AS", "3AS", "IR", "ME"), + transcript_filter: Optional[str] = None, + ) -> list[tuple]: + """Performs pairwise independence test for all pairs of Alternative Splicing Events (ASEs) in a gene. For all pairs of ASEs in a gene creates a contingency table and performs an independence test. All ASEs A have two states, pri_A and alt_A, the primary and the alternative state respectivley. 
@@ -849,7 +1206,7 @@ def coordination_test(self, samples=None, test: Literal['fisher', 'chi2'] = "fis :return: A list of tuples with the test results: (gene_id, gene_name, strand, eventA_type, eventB_type, eventA_start, eventA_end, eventB_start, eventB_end, p_value, test_stat, log2OR, dcPSI_AB, dcPSI_BA, priA_priB, priA_altB, altA_priB, altA_altB, priA_priB_transcript_ids, priA_altB_transcript_ids, altA_priB_transcript_ids, altA_altB_transcript_ids). - ''' + """ if samples is None: cov = self.coverage.sum(axis=0) @@ -860,6 +1217,7 @@ def coordination_test(self, samples=None, test: Literal['fisher', 'chi2'] = "fis except IndexError: # Fall back to looking up the sample indices from isotools._transcriptome_stats import _check_groups + _, _, groups = _check_groups(self._transcriptome, [samples], 1) cov = self.coverage[groups[0]].sum(0) @@ -871,11 +1229,20 @@ def coordination_test(self, samples=None, test: Literal['fisher', 'chi2'] = "fis if events is None: events = sg.find_splice_bubbles(types=event_type) - events = [event for event in events - if _filter_event(cov, event, segment_graph=sg, min_total=min_total, - min_alt_fraction=min_alt_fraction, min_dist_AB=min_dist_AB)] + events = [ + event + for event in events + if _filter_event( + cov, + event, + segment_graph=sg, + min_total=min_total, + min_alt_fraction=min_alt_fraction, + min_dist_AB=min_dist_AB, + ) + ] # make sure its sorted (according to gene strand) - if self.strand == '+': + if self.strand == "+": events.sort(key=itemgetter(2, 3), reverse=False) # sort by starting node else: events.sort(key=itemgetter(3, 2), reverse=True) # reverse sort by end node @@ -884,22 +1251,37 @@ def coordination_test(self, samples=None, test: Literal['fisher', 'chi2'] = "fis for event1, event2 in itertools.combinations(events, 2): if sg.events_dist(event1, event2) < min_dist_events: continue - if (event1[4], event2[4]) == ("TSS", "TSS") or (event1[4], event2[4]) == ("PAS", "PAS"): + if (event1[4], event2[4]) == ("TSS", "TSS") or 
(event1[4], event2[4]) == ( + "PAS", + "PAS", + ): continue con_tab, tr_ID_tab = prepare_contingency_table(event1, event2, cov) - if con_tab.sum(None) < min_total: # check that the joint occurrence of the two events passes the threshold + if ( + con_tab.sum(None) < min_total + ): # check that the joint occurrence of the two events passes the threshold continue - if min(con_tab.sum(1).min(), con_tab.sum(0).min())/con_tab.sum(None) < min_alt_fraction: + if ( + min(con_tab.sum(1).min(), con_tab.sum(0).min()) / con_tab.sum(None) + < min_alt_fraction + ): continue - test_result = pairwise_event_test(con_tab, test=test) # append to test result + test_result = pairwise_event_test( + con_tab, test=test + ) # append to test result coordinate1 = sg._get_event_coordinate(event1) coordinate2 = sg._get_event_coordinate(event2) - attr = (self.id, self.name, self.strand, event1[4], event2[4]) + \ - coordinate1 + coordinate2 + test_result + \ - tuple(con_tab.flatten()) + tuple(tr_ID_tab.flatten()) + attr = ( + (self.id, self.name, self.strand, event1[4], event2[4]) + + coordinate1 + + coordinate2 + + test_result + + tuple(con_tab.flatten()) + + tuple(tr_ID_tab.flatten()) + ) # event1[4] is the event1 type # coordinate1[0] is the starting coordinate of event 1 @@ -912,14 +1294,15 @@ def coordination_test(self, samples=None, test: Literal['fisher', 'chi2'] = "fis return test_res def die_test(self, groups, min_cov=25, n_isoforms=10): - ''' Reimplementation of the DIE test, suggested by Joglekar et al in Nat Commun 12, 463 (2021): + """Reimplementation of the DIE test, suggested by Joglekar et al in Nat Commun 12, 463 (2021): "A spatially resolved brain region- and cell type-specific isoform atlas of the postnatal mouse brain" Syntax and parameters follow the original implementation in https://github.com/noush-joglekar/scisorseqr/blob/master/inst/RScript/IsoformTest.R :param groups: Define the columns for the groups. :param min_cov: Minimal number of reads per group for the gene. 
- :param n_isoforms: Number of isoforms to consider in the test for the gene. All additional least expressed isoforms get summarized.''' + :param n_isoforms: Number of isoforms to consider in the test for the gene. All additional least expressed isoforms get summarized. + """ # select the samples and sum the group counts try: # Fast mode when testing several genes @@ -927,6 +1310,7 @@ def die_test(self, groups, min_cov=25, n_isoforms=10): except IndexError: # Fall back to looking up the sample indices from isotools._transcriptome_stats import _check_groups + _, _, groups = _check_groups(self._transcriptome, groups) cov = np.array([self.coverage[grp].sum(0) for grp in groups]).T @@ -934,11 +1318,13 @@ def die_test(self, groups, min_cov=25, n_isoforms=10): return np.nan, np.nan, [] # if there are more than 'numIsoforms' isoforms of the gene, all additional least expressed get summarized. if cov.shape[0] > n_isoforms: - idx = np.argpartition(-cov.sum(1), n_isoforms) # take the n_isoforms most expressed isoforms (random order) + idx = np.argpartition( + -cov.sum(1), n_isoforms + ) # take the n_isoforms most expressed isoforms (random order) additional = cov[idx[n_isoforms:]].sum(0) cov = cov[idx[:n_isoforms]] - cov[n_isoforms-1] += additional - idx[n_isoforms-1] = -1 # this isoform gets all other - I give it index + cov[n_isoforms - 1] += additional + idx[n_isoforms - 1] = -1 # this isoform gets all other - I give it index elif cov.shape[0] < 2: return np.nan, np.nan, [] else: @@ -946,10 +1332,10 @@ def die_test(self, groups, min_cov=25, n_isoforms=10): try: _, pval, _, _ = chi2_contingency(cov) except ValueError: - logger.error(f'chi2_contingency({cov})') + logger.error(f"chi2_contingency({cov})") raise - iso_frac = cov/cov.sum(0) - deltaPI = iso_frac[..., 0]-iso_frac[..., 1] + iso_frac = cov / cov.sum(0) + deltaPI = iso_frac[..., 0] - iso_frac[..., 1] order = np.argsort(deltaPI) pos_idx = [order[-i] for i in range(1, 3) if deltaPI[order[-i]] > 0] neg_idx = [order[i] 
for i in range(2) if deltaPI[order[i]] < 0] @@ -960,61 +1346,77 @@ def die_test(self, groups, min_cov=25, n_isoforms=10): else: return pval, deltaPI_neg, idx[neg_idx] - def _unify_ends(self, smooth_window=31, rel_prominence=1, search_range: tuple[float, float] = (.1, .9), correct_tss=True): - ''' Find common TSS/PAS for transcripts of the gene''' + def _unify_ends( + self, + smooth_window=31, + rel_prominence=1, + search_range: tuple[float, float] = (0.1, 0.9), + correct_tss=True, + ): + """Find common TSS/PAS for transcripts of the gene""" if not self.transcripts: # nothing to do here return - assert 0 <= search_range[0] <= .5 <= search_range[1] <= 1 + assert 0 <= search_range[0] <= 0.5 <= search_range[1] <= 1 # get gene tss/pas profiles tss: dict[int, int] = {} pas: dict[int, int] = {} - strand = 1 if self.strand == '+' else -1 + strand = 1 if self.strand == "+" else -1 for transcript in self.transcripts: - for sample in transcript['TSS']: - for pos, c in transcript['TSS'][sample].items(): - tss[pos] = tss.get(pos, 0)+c - for sample in transcript['PAS']: - for pos, c in transcript['PAS'][sample].items(): - pas[pos] = pas.get(pos, 0)+c + for sample in transcript["TSS"]: + for pos, c in transcript["TSS"][sample].items(): + tss[pos] = tss.get(pos, 0) + c + for sample in transcript["PAS"]: + for pos, c in transcript["PAS"][sample].items(): + pas[pos] = pas.get(pos, 0) + c tss_pos = [min(tss), max(tss)] - if tss_pos[1]-tss_pos[0] < smooth_window: + if tss_pos[1] - tss_pos[0] < smooth_window: tss_pos[0] = tss_pos[1] - smooth_window + 1 pas_pos = [min(pas), max(pas)] - if pas_pos[1]-pas_pos[0] < smooth_window: + if pas_pos[1] - pas_pos[0] < smooth_window: pas_pos[0] = pas_pos[1] - smooth_window + 1 - tss = [tss.get(pos, 0) for pos in range(tss_pos[0], tss_pos[1]+1)] - pas = [pas.get(pos, 0) for pos in range(pas_pos[0], pas_pos[1]+1)] + tss = [tss.get(pos, 0) for pos in range(tss_pos[0], tss_pos[1] + 1)] + pas = [pas.get(pos, 0) for pos in range(pas_pos[0], pas_pos[1] + 
1)] # smooth profiles and find maxima tss_smooth = smooth(np.array(tss), smooth_window) pas_smooth = smooth(np.array(pas), smooth_window) # at least half of smooth_window reads required to call a peak # minimal distance between peaks is > ~ smooth_window # rel_prominence=1 -> smaller peak must have twice the height of valley to call two peaks - tss_peaks, _ = find_peaks(np.log2(tss_smooth+1), prominence=(rel_prominence, None)) - tss_peak_pos: list[int] = tss_peaks+tss_pos[0]-1 - pas_peaks, _ = find_peaks(np.log2(pas_smooth+1), prominence=(rel_prominence, None)) - pas_peak_pos: list[int] = pas_peaks+pas_pos[0]-1 + tss_peaks, _ = find_peaks( + np.log2(tss_smooth + 1), prominence=(rel_prominence, None) + ) + tss_peak_pos: list[int] = tss_peaks + tss_pos[0] - 1 + pas_peaks, _ = find_peaks( + np.log2(pas_smooth + 1), prominence=(rel_prominence, None) + ) + pas_peak_pos: list[int] = pas_peaks + pas_pos[0] - 1 # find transcripts with common first/last splice site first_junction: dict[int, list[int]] = {} last_junction: dict[int, list[int]] = {} for transcript_id, transcript in enumerate(self.transcripts): - first_junction.setdefault(transcript['exons'][0][1], []).append(transcript_id) - last_junction.setdefault(transcript['exons'][-1][0], []).append(transcript_id) + first_junction.setdefault(transcript["exons"][0][1], []).append( + transcript_id + ) + last_junction.setdefault(transcript["exons"][-1][0], []).append( + transcript_id + ) # first / last junction with respect to direction of transcription - if self.strand == '-': + if self.strand == "-": first_junction, last_junction = last_junction, first_junction # for each site, find consistent "peaks" TSS/PAS # if none found use median of all read starts for junction_pos, transcript_ids in first_junction.items(): profile = {} for transcript_id in transcript_ids: - for sample_tss in self.transcripts[transcript_id]['TSS'].values(): + for sample_tss in self.transcripts[transcript_id]["TSS"].values(): for pos, c in 
sample_tss.items(): profile[pos] = profile.get(pos, 0) + c - quantiles = get_quantiles(sorted(profile.items()), [search_range[0], .5, search_range[1]]) + quantiles = get_quantiles( + sorted(profile.items()), [search_range[0], 0.5, search_range[1]] + ) # one/ several peaks within base range? -> quantify by next read_start # else use median ol_peaks = [p for p in tss_peak_pos if quantiles[0] < p <= quantiles[-1]] @@ -1022,23 +1424,32 @@ def _unify_ends(self, smooth_window=31, rel_prominence=1, search_range: tuple[fl ol_peaks = [quantiles[1]] for transcript_id in transcript_ids: transcript = self.transcripts[transcript_id] - transcript['TSS_unified'] = {} - for sample, sample_tss in transcript['TSS'].items(): + transcript["TSS_unified"] = {} + for sample, sample_tss in transcript["TSS"].items(): tss_unified: dict[int, int] = {} # for each read start position, find closest peak for pos, c in sample_tss.items(): - next_peak = min((p for p in ol_peaks if cmp_dist(junction_pos, p, min_dist=3) == strand), - default=pos, key=lambda x: abs(x-pos)) - tss_unified[next_peak] = tss_unified.get(next_peak, 0)+c - transcript['TSS_unified'][sample] = tss_unified + next_peak = min( + ( + p + for p in ol_peaks + if cmp_dist(junction_pos, p, min_dist=3) == strand + ), + default=pos, + key=lambda x: abs(x - pos), + ) + tss_unified[next_peak] = tss_unified.get(next_peak, 0) + c + transcript["TSS_unified"][sample] = tss_unified # same for PAS for junction_pos, transcript_ids in last_junction.items(): profile = {} for transcript_id in transcript_ids: - for sa_pas in self.transcripts[transcript_id]['PAS'].values(): + for sa_pas in self.transcripts[transcript_id]["PAS"].values(): for pos, c in sa_pas.items(): - profile[pos] = profile.get(pos, 0)+c - quantiles = get_quantiles(sorted(profile.items()), [search_range[0], .5, search_range[1]]) + profile[pos] = profile.get(pos, 0) + c + quantiles = get_quantiles( + sorted(profile.items()), [search_range[0], 0.5, search_range[1]] + ) # one/ several 
peaks within base range? -> quantify by next read_start # else use median ol_peaks = [p for p in pas_peak_pos if quantiles[0] < p <= quantiles[-1]] @@ -1046,79 +1457,128 @@ def _unify_ends(self, smooth_window=31, rel_prominence=1, search_range: tuple[fl ol_peaks = [quantiles[1]] for transcript_id in transcript_ids: transcript = self.transcripts[transcript_id] - transcript['PAS_unified'] = {} - for sample, sa_pas in transcript['PAS'].items(): + transcript["PAS_unified"] = {} + for sample, sa_pas in transcript["PAS"].items(): pas_unified: dict[int, int] = {} for pos, c in sa_pas.items(): - next_peak = min((p for p in ol_peaks if cmp_dist(p, junction_pos, min_dist=3) == strand), - default=pos, key=lambda x: abs(x-pos)) - pas_unified[next_peak] = pas_unified.get(next_peak, 0)+c - transcript['PAS_unified'][sample] = pas_unified + next_peak = min( + ( + p + for p in ol_peaks + if cmp_dist(p, junction_pos, min_dist=3) == strand + ), + default=pos, + key=lambda x: abs(x - pos), + ) + pas_unified[next_peak] = pas_unified.get(next_peak, 0) + c + transcript["PAS_unified"][sample] = pas_unified for transcript in self.transcripts: # find the most common tss/pas per transcript, and set the exon boundaries sum_tss: dict[int, int] = {} sum_pas: dict[int, int] = {} start = end = max_tss = max_pas = 0 - for sample_tss in transcript['TSS_unified'].values(): + for sample_tss in transcript["TSS_unified"].values(): for pos, cov in sample_tss.items(): - sum_tss[pos] = sum_tss.get(pos, 0)+cov + sum_tss[pos] = sum_tss.get(pos, 0) + cov for pos, cov in sum_tss.items(): if cov > max_tss: max_tss = cov start = pos - for sa_pas in transcript['PAS_unified'].values(): + for sa_pas in transcript["PAS_unified"].values(): for pos, cov in sa_pas.items(): sum_pas[pos] = sum_pas.get(pos, 0) + cov for pos, cov in sum_pas.items(): if cov > max_pas: max_pas = cov end = pos - if self.strand == '-': + if self.strand == "-": start, end = end, start if start >= end: # for monoexons this may happen in rare 
situations - assert len(transcript['exons']) == 1 - transcript['TSS_unified'] = None - transcript['PAS_unified'] = None + assert len(transcript["exons"]) == 1 + transcript["TSS_unified"] = None + transcript["PAS_unified"] = None else: try: # issues if the new exon start is behind the exon end - assert start < transcript['exons'][0][1] or len(transcript['exons']) == 1, 'error unifying %s: %s>=%s' % (transcript["exons"], start, transcript['exons'][0][1]) - transcript['exons'][0][0] = start - assert end > transcript['exons'][-1][0] or len(transcript['exons']) == 1, 'error unifying %s: %s<=%s' % (transcript["exons"], end, transcript['exons'][-1][0]) - transcript['exons'][-1][1] = end + assert ( + start < transcript["exons"][0][1] + or len(transcript["exons"]) == 1 + ), "error unifying %s: %s>=%s" % ( + transcript["exons"], + start, + transcript["exons"][0][1], + ) + transcript["exons"][0][0] = start + assert ( + end > transcript["exons"][-1][0] + or len(transcript["exons"]) == 1 + ), "error unifying %s: %s<=%s" % ( + transcript["exons"], + end, + transcript["exons"][-1][0], + ) + transcript["exons"][-1][1] = end except AssertionError: - logger.error('%s TSS= %s, PAS=%s -> TSS_unified= %s, PAS_unified=%s', self, transcript['TSS'], transcript['PAS'], transcript['TSS_unified'], transcript['PAS_unified']) + logger.error( + "%s TSS= %s, PAS=%s -> TSS_unified= %s, PAS_unified=%s", + self, + transcript["TSS"], + transcript["PAS"], + transcript["TSS_unified"], + transcript["PAS_unified"], + ) raise if correct_tss: self._TSS_correction(transcript) - def _TSS_correction(self, transcript: Transcript): - ''' + """ Correct TSS to the closest upstream reference TSS from best peak. Don't extend past exon ends, because that can introduce artificial intron retention events. 
- ''' + """ if self.is_annotated: - if transcript['strand'] == '+': - tss = transcript['exons'][0][0] - ref_tsss = [ref_transcript['exons'][0][0] for ref_transcript in self.ref_transcripts if ref_transcript['exons'][0][0] <= tss] + if transcript["strand"] == "+": + tss = transcript["exons"][0][0] + ref_tsss = [ + ref_transcript["exons"][0][0] + for ref_transcript in self.ref_transcripts + if ref_transcript["exons"][0][0] <= tss + ] if ref_tsss: new_tss = max(ref_tsss) # Find ref upstream exon ends that are between the old and the new TSS - ref_exon_ends = [exon[1] for transcript in self.ref_transcripts for exon in transcript["exons"] if exon[1] <= tss and exon[1] >= new_tss] + ref_exon_ends = [ + exon[1] + for transcript in self.ref_transcripts + for exon in transcript["exons"] + if exon[1] <= tss and exon[1] >= new_tss + ] # Don't extend past exon ends if not ref_exon_ends: - logger.debug(f'Corrected TSS ({transcript["strand"]} strand) from {tss} to {new_tss}') - transcript['exons'][0][0] = new_tss + logger.debug( + f'Corrected TSS ({transcript["strand"]} strand) from {tss} to {new_tss}' + ) + transcript["exons"][0][0] = new_tss else: - tss = transcript['exons'][-1][1] - ref_tsss = [ref_transcript['exons'][-1][1] for ref_transcript in self.ref_transcripts if ref_transcript['exons'][-1][1] >= tss] + tss = transcript["exons"][-1][1] + ref_tsss = [ + ref_transcript["exons"][-1][1] + for ref_transcript in self.ref_transcripts + if ref_transcript["exons"][-1][1] >= tss + ] if ref_tsss: new_tss = min(ref_tsss) - ref_exon_ends = [exon[0] for transcript in self.ref_transcripts for exon in transcript["exons"] if exon[0] >= tss and exon[0] <= new_tss] + ref_exon_ends = [ + exon[0] + for transcript in self.ref_transcripts + for exon in transcript["exons"] + if exon[0] >= tss and exon[0] <= new_tss + ] if not ref_exon_ends: - logger.debug(f'Corrected TSS ({transcript["strand"]} strand) from {tss} to {new_tss}') - transcript['exons'][-1][1] = new_tss + logger.debug( + 
f'Corrected TSS ({transcript["strand"]} strand) from {tss} to {new_tss}' + ) + transcript["exons"][-1][1] = new_tss def _coding_len(exons: list[tuple[int, int]], cds): @@ -1127,7 +1587,9 @@ def _coding_len(exons: list[tuple[int, int]], cds): for exon in exons: if state < 2 and exon[1] >= cds[state]: coding_len[state] += cds[state] - exon[0] - if state == 0 and cds[1] <= exon[1]: # special case: CDS start and end in same exon + if ( + state == 0 and cds[1] <= exon[1] + ): # special case: CDS start and end in same exon coding_len[1] = cds[1] - cds[0] coding_len[2] = exon[1] - cds[1] state += 2 @@ -1140,33 +1602,33 @@ def _coding_len(exons: list[tuple[int, int]], cds): def repeat_len(seq1, seq2, wobble, max_mm): - ''' Calcluate direct repeat length between seq1 and seq2 - ''' - score = [0]*(2*wobble+1) - delta = int(len(seq1)/2-wobble) - for w in range(2*wobble+1): # wobble - s1 = seq1[w:len(seq1)-(2*wobble-w)] - s2 = seq2[wobble:len(seq2)-wobble] + """Calcluate direct repeat length between seq1 and seq2""" + score = [0] * (2 * wobble + 1) + delta = int(len(seq1) / 2 - wobble) + for w in range(2 * wobble + 1): # wobble + s1 = seq1[w : len(seq1) - (2 * wobble - w)] + s2 = seq2[wobble : len(seq2) - wobble] align = [a == b for a, b in zip(s1, s2)] score_left = find_runlength(reversed(align[:delta]), max_mm) score_right = find_runlength(align[delta:], max_mm) - score[w] = max([score_left[fmm]+score_right[max_mm-fmm] for fmm in range(max_mm+1)]) + score[w] = max( + [score_left[fmm] + score_right[max_mm - fmm] for fmm in range(max_mm + 1)] + ) return max(score) def find_runlength(align, max_mm): - '''Find the runlength, e.g. the number of True in the list before the max_mm+1 False occur. - ''' - score = [0]*(max_mm+1) + """Find the runlength, e.g. 
the number of True in the list before the max_mm+1 False occur.""" + score = [0] * (max_mm + 1) mm = 0 for a in align: if not a: mm += 1 if mm > max_mm: return score - score[mm] = score[mm-1] + score[mm] = score[mm - 1] else: score[mm] += 1 - for i in range(mm+1, max_mm+1): - score[i] = score[i-1] + for i in range(mm + 1, max_mm + 1): + score[i] = score[i - 1] return score diff --git a/src/isotools/plots.py b/src/isotools/plots.py index 9e5a9cb..5f01c42 100644 --- a/src/isotools/plots.py +++ b/src/isotools/plots.py @@ -6,12 +6,24 @@ import numpy as np import ternary import logging -logger = logging.getLogger('isotools') +logger = logging.getLogger("isotools") -def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5), min_cov=10, splice_types=None, - group_colors=None, sample_colors=None, pt_size=20, lw=1, ls='solid'): - '''Plots differential splicing results. + +def plot_diff_results( + result_table, + min_support=3, + min_diff=0.1, + grid_shape=(5, 5), + min_cov=10, + splice_types=None, + group_colors=None, + sample_colors=None, + pt_size=20, + lw=1, + ls="solid", +): + """Plots differential splicing results. For the first (e.g. most significant) differential splicing events from result_table that pass the checks defined by the parameters, @@ -32,7 +44,7 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5 :param lw: Specify witdh of the lines. See matplotlib Line2D for details. :param ls: Specify style of the lines. See matplotlib Line2D for details. 
:return: figure, axes and list of plotted events - ''' + """ plotted = {} # pd.DataFrame(columns=result_table.columns) if isinstance(splice_types, str): @@ -40,43 +52,87 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5 f, axs = plt.subplots(*grid_shape) axs = axs.flatten() x = [i / 100 for i in range(101)] - group_names = [col[:-4] for col in result_table.columns if col.endswith('_PSI')][:2] - groups = {group_name: [col[:col.rfind(group_name)-1] for col in result_table.columns if col.endswith(group_name + '_total_cov')] for group_name in group_names} + group_names = [col[:-4] for col in result_table.columns if col.endswith("_PSI")][:2] + groups = { + group_name: [ + col[: col.rfind(group_name) - 1] + for col in result_table.columns + if col.endswith(group_name + "_total_cov") + ] + for group_name in group_names + } if group_colors is None: - group_colors = ['C0', 'C1'] + group_colors = ["C0", "C1"] if isinstance(group_colors, list): group_colors = dict(zip(group_names, group_colors)) if sample_colors is None: sample_colors = {} - sample_colors = {sample: sample_colors.get(sample, group_colors[name]) for name in group_names for sample in groups[name]} + sample_colors = { + sample: sample_colors.get(sample, group_colors[name]) + for name in group_names + for sample in groups[name] + } other = {group_names[0]: group_names[1], group_names[1]: group_names[0]} - logger.debug('groups: %s', str(groups)) + logger.debug("groups: %s", str(groups)) for idx, row in result_table.iterrows(): - logger.debug('plotting %s: %s', idx, row.gene) + logger.debug("plotting %s: %s", idx, row.gene) if splice_types is not None and row.splice_type not in splice_types: continue if row.gene in plotted: continue - params_alt = {group_name: (row[f'{group_name}_PSI'], row[f'{group_name}_disp']) for group_name in group_names} + params_alt = { + group_name: (row[f"{group_name}_PSI"], row[f"{group_name}_disp"]) + for group_name in group_names + } # select only 
samples covered >= min_cov - # psi_gr = {groupname: [row[f'{sample}_in_cov'] / row[f'{sample}_total_cov'] for sample in group if row[f'{sample}_total_cov'] >= min_cov] for groupname, group in groups.items()} - psi_gr_list = [(sample, groupname, row[f'{sample}_{groupname}_in_cov'] / row[f'{sample}_{groupname}_total_cov']) - for groupname, group in groups.items() for sample in group if row[f'{sample}_{groupname}_total_cov'] >= min_cov] - psi_gr = pd.DataFrame(psi_gr_list, columns=['sample', 'group', 'psi']) - psi_gr['support'] = [abs(sample.psi - params_alt[sample['group']][0]) < abs(sample.psi - params_alt[other[sample['group']]][0]) for i, sample in psi_gr.iterrows()] - support = dict(psi_gr.groupby('group')['support'].sum()) + # psi_gr = {groupname: [row[f'{sample}_in_cov'] / row[f'{sample}_total_cov'] + # for sample in group if row[f'{sample}_total_cov'] >= min_cov] for groupname, group in groups.items()} + psi_gr_list = [ + ( + sample, + groupname, + row[f"{sample}_{groupname}_in_cov"] + / row[f"{sample}_{groupname}_total_cov"], + ) + for groupname, group in groups.items() + for sample in group + if row[f"{sample}_{groupname}_total_cov"] >= min_cov + ] + psi_gr = pd.DataFrame(psi_gr_list, columns=["sample", "group", "psi"]) + psi_gr["support"] = [ + abs(sample.psi - params_alt[sample["group"]][0]) + < abs(sample.psi - params_alt[other[sample["group"]]][0]) + for i, sample in psi_gr.iterrows() + ] + support = dict(psi_gr.groupby("group")["support"].sum()) if any(sup < min_support for sup in support.values()): - logger.debug('skipping %s with %s supporters', row.gene, support) + logger.debug("skipping %s with %s supporters", row.gene, support) continue - if abs(params_alt[group_names[0]][0] - params_alt[group_names[1]][0]) < min_diff: - logger.debug('%s with %s', row.gene, "vs".join(str(p[0]) for p in params_alt.values())) + if ( + abs(params_alt[group_names[0]][0] - params_alt[group_names[1]][0]) + < min_diff + ): + logger.debug( + "%s with %s", + row.gene, + 
"vs".join(str(p[0]) for p in params_alt.values()), + ) continue # get the paramters for the beta distiribution ax = axs[len(plotted)] # ax.boxplot([mut,wt], labels=['mut','wt']) - sns.swarmplot(data=psi_gr, x='psi', y='group', hue='sample', orient='h', size=np.sqrt(pt_size), palette=sample_colors, ax=ax) + sns.swarmplot( + data=psi_gr, + x="psi", + y="group", + hue="sample", + orient="h", + size=np.sqrt(pt_size), + palette=sample_colors, + ax=ax, + ) ax.legend([], [], frameon=False) - for i, group_name in enumerate(group_names): + for _, group_name in enumerate(group_names): max_i = int(params_alt[group_name][0] * (len(x) - 1)) ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis if params_alt[group_name][1] > 0: @@ -89,17 +145,30 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5 y[max_i] = 1 # point mass ax2.plot(x, y, color=group_colors[group_name], lw=lw, ls=ls) ax2.tick_params(right=False, labelright=False) - ax.set_title(f'{row.gene} {row.splice_type}\nFDR={row.padj:.5f}') + ax.set_title(f"{row.gene} {row.splice_type}\nFDR={row.padj:.5f}") plotted[row.gene] = row if len(plotted) == len(axs): break return f, axs, pd.concat(plotted.values()) -def plot_embedding(splice_bubbles, method='PCA', prior_count=3, - top_var=500, min_total=100, min_alt_fraction=.1, plot_components=(1, 2), - splice_types='all', labels=True, groups=None, colors=None, pt_size=20, ax=None, **kwargs): - ''' Plots embedding of alternative splicing events. +def plot_embedding( + splice_bubbles, + method="PCA", + prior_count=3, + top_var=500, + min_total=100, + min_alt_fraction=0.1, + plot_components=(1, 2), + splice_types="all", + labels=True, + groups=None, + colors=None, + pt_size=20, + ax=None, + **kwargs, +): + """Plots embedding of alternative splicing events. Alternative splicing events are soreted by variance and only the top variable events are used for the embedding. 
A prior weight is added to all samples proportional to the average fraction of the alternatives, @@ -120,10 +189,11 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3, :param pt_size: Specify the size for the data points in the plot. :param ax: The axis for plotting. :param \\**kwargs: Additional keyword parameters are passed to PCA() or UMAP(). - :return: A dataframe with the proportions of the alternative events, the transformed data and the embedding object.''' + :return: A dataframe with the proportions of the alternative events, the transformed data and the embedding object. + """ - assert method in ['PCA', 'UMAP'], 'method must be PCA or UMAP' - if method == 'UMAP': + assert method in ["PCA", "UMAP"], "method must be PCA or UMAP" + if method == "UMAP": # umap import takes ~15 seconds, hence the lazy import here from umap import UMAP as Embedding # pylint: disable-msg=E0611 else: @@ -132,39 +202,54 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3, plot_components = np.array(plot_components) if isinstance(splice_types, str): splice_types = [splice_types] - if 'all' not in splice_types: - splice_bubbles = splice_bubbles.loc[splice_bubbles['splice_type'].isin(splice_types)] - k = splice_bubbles[[c for c in splice_bubbles.columns if c.endswith('_in_cov')]] - n = splice_bubbles[[c for c in splice_bubbles.columns if c.endswith('_total_cov')]] + if "all" not in splice_types: + splice_bubbles = splice_bubbles.loc[ + splice_bubbles["splice_type"].isin(splice_types) + ] + k = splice_bubbles[[c for c in splice_bubbles.columns if c.endswith("_in_cov")]] + n = splice_bubbles[[c for c in splice_bubbles.columns if c.endswith("_total_cov")]] n.columns = [c[:-10] for c in n.columns] k.columns = [c[:-7] for c in k.columns] samples = list(n.columns) - assert all(c1 == c2 for c1, c2 in zip(n.columns, k.columns)), 'issue with sample naming of splice bubble table' + assert all( + c1 == c2 for c1, c2 in zip(n.columns, k.columns) + ), "issue with 
sample naming of splice bubble table" # select samples and assing colors if groups is None: - groups = {'all samples': samples} + groups = {"all samples": samples} else: - sa_group = {sample: groupname for groupname, sample_list in groups.items() for sample in sample_list if sample in samples} + sa_group = { + sample: groupname + for groupname, sample_list in groups.items() + for sample in sample_list + if sample in samples + } if len(samples) > len(sa_group): samples = [sample for sample in samples if sample in sa_group] - logger.info('restricting embedding on samples ' + ', '.join(samples)) + logger.info("restricting embedding on samples " + ", ".join(samples)) n = n[samples] k = k[samples] if colors is None: - cm = plt.get_cmap('gist_rainbow') + cm = plt.get_cmap("gist_rainbow") colors = {gn: to_hex(cm(i / len(groups))) for i, gn in enumerate(groups)} elif isinstance(colors, dict): - assert all(gn in colors for gn in groups), 'not all groups have colors' - assert all(is_color_like(c) for c in colors.values()), 'invalid colors' + assert all(gn in colors for gn in groups), "not all groups have colors" + assert all(is_color_like(c) for c in colors.values()), "invalid colors" elif len(colors) >= len(groups): - assert all(is_color_like(c) for c in colors), 'invalid colors' + assert all(is_color_like(c) for c in colors), "invalid colors" colors = {gn: colors[i] for i, gn in enumerate(groups)} else: - raise ValueError(f'number of colors ({len(colors)}) does not match number of groups ({len(groups)})') + raise ValueError( + f"number of colors ({len(colors)}) does not match number of groups ({len(groups)})" + ) nsum = n.sum(1) ksum = k.sum(1) - covered = (nsum >= min_total) & (min_alt_fraction < ksum / nsum) & (ksum / nsum < 1 - min_alt_fraction) + covered = ( + (nsum >= min_total) + & (min_alt_fraction < ksum / nsum) + & (ksum / nsum < 1 - min_alt_fraction) + ) n = n.loc[covered] k = k.loc[covered] # compute the proportions @@ -173,16 +258,22 @@ def 
plot_embedding(splice_bubbles, method='PCA', prior_count=3, topvar = p[:, p.var(0).argsort()[-top_var:]] # sort from low to high var # compute embedding - kwargs.setdefault('n_components', max(plot_components)) - assert kwargs['n_components'] >= max(plot_components), 'n_components is smaller than the largest selected component' + kwargs.setdefault("n_components", max(plot_components)) + assert kwargs["n_components"] >= max( + plot_components + ), "n_components is smaller than the largest selected component" # Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. # The input data is centered but not scaled for each feature before applying the SVD. embedding = Embedding(**kwargs).fit(topvar) axparams = dict(title=f'{method} ({",".join(splice_types)})') - if method == 'PCA': - axparams['xlabel'] = f'PC{plot_components[0]} ({embedding.explained_variance_ratio_[plot_components[0]-1]*100:.2f} %)' - axparams['ylabel'] = f'PC{plot_components[1]} ({embedding.explained_variance_ratio_[plot_components[1]-1]*100:.2f} %)' + if method == "PCA": + axparams["xlabel"] = ( + f"PC{plot_components[0]} ({embedding.explained_variance_ratio_[plot_components[0]-1]*100:.2f} %)" + ) + axparams["ylabel"] = ( + f"PC{plot_components[1]} ({embedding.explained_variance_ratio_[plot_components[1]-1]*100:.2f} %)" + ) transformed = pd.DataFrame(embedding.transform(topvar), index=samples) if ax is None: @@ -191,18 +282,32 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3, ax.scatter( transformed.loc[sample, plot_components[0] - 1], transformed.loc[sample, plot_components[1] - 1], - c=colors[group], label=group, s=pt_size) + c=colors[group], + label=group, + s=pt_size, + ) ax.set(**axparams) if labels: for idx, (x, y) in transformed[plot_components - 1].iterrows(): ax.text(x, y, s=idx) return pd.DataFrame(p.T, columns=samples, index=k.index), transformed, embedding + # plots -def plot_bar(df, ax=None, 
drop_categories=None, legend=True, annotate=True, rot=90, bar_width=.5, colors=None, **axparams): - '''Depicts data as a barplot. +def plot_bar( + df, + ax=None, + drop_categories=None, + legend=True, + annotate=True, + rot=90, + bar_width=0.5, + colors=None, + **axparams, +): + """Depicts data as a barplot. This function is intended to be called with the result from isoseq.Transcriptome.filter_stats() or isoseq.Transcriptome.altsplice_stats(). @@ -215,23 +320,27 @@ def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot= :param rot: Set rotation of the lables. :param bar_width: Set relative width of the plotted bars. :param colors: Provide a dictionary with label keys and color values. By default, colors are automatically assigned. - :param \\**axparams: Additional keyword parameters are passed to ax.set().''' + :param \\**axparams: Additional keyword parameters are passed to ax.set().""" if ax is None: _, ax = plt.subplots() - if 'total' in df.index: - total = df.loc['total'] - df = df.drop('total') + if "total" in df.index: + total = df.loc["total"] + df = df.drop("total") else: total = df.sum() - fractions = (df / total * 100) + fractions = df / total * 100 if drop_categories is None: dcat = [] else: dcat = [d for d in drop_categories if d in df.index] if colors is None: - colors = [f'C{i}' for i in range(len(df.index)-len(dcat))] # plot.bar cannot deal with color=None - fractions.drop(dcat).plot.bar(ax=ax, legend=legend, width=bar_width, rot=rot, color=colors) + colors = [ + f"C{i}" for i in range(len(df.index) - len(dcat)) + ] # plot.bar cannot deal with color=None + fractions.drop(dcat).plot.bar( + ax=ax, legend=legend, width=bar_width, rot=rot, color=colors + ) # add numbers if annotate: numbers = [int(v) for c in df.drop(dcat).T.values for v in c] @@ -239,16 +348,34 @@ def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot= for n, f, p in zip(numbers, frac, ax.patches): small = f < max(frac) / 2 # 
contrast=tuple(1-cv for cv in p.get_facecolor()[:3]) - contrast = 'white' if np.mean(p.get_facecolor()[:3]) < .5 else 'black' - ax.annotate(f' {f/100:.2%} ({n}) ', (p.get_x() + p.get_width() / 2, p.get_height()), ha='center', - va='bottom' if small else 'top', rotation=90, color='black' if small else contrast, fontweight='bold') + contrast = "white" if np.mean(p.get_facecolor()[:3]) < 0.5 else "black" + ax.annotate( + f" {f/100:.2%} ({n}) ", + (p.get_x() + p.get_width() / 2, p.get_height()), + ha="center", + va="bottom" if small else "top", + rotation=90, + color="black" if small else contrast, + fontweight="bold", + ) ax.set(**axparams) return ax -def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=True, lw=1, ls='solid', colors=None, **axparams): - '''Depicts data as density plot. +def plot_distr( + counts, + ax=None, + density=False, + smooth=None, + legend=True, + fill=True, + lw=1, + ls="solid", + colors=None, + **axparams, +): + """Depicts data as density plot. This function is intended to be called with the result from isoseq.Transcriptome.transcript_length_hist(), isoseq.Transcriptome.transcripts_per_gene_hist(), @@ -264,7 +391,7 @@ def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=Tr :param lw: Specify witdh of the lines. See matplotlib Line2D for details. :param ls: Specify style of the lines. See matplotlib Line2D for details. :param colors: Provide a dictionary with label keys and color values. By default, colors are automatically assigned. 
- :param \\**axparams: Additional keyword parameters are passed to ax.set().''' + :param \\**axparams: Additional keyword parameters are passed to ax.set().""" # maybe add smoothing x = [sum(bin) / 2 for bin in counts.index] @@ -274,19 +401,19 @@ def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=Tr if ax is None: _, ax = plt.subplots() if density: - counts = (counts / counts.sum()) - if 'ylabel' in axparams and 'density' not in axparams['ylabel']: - axparams['ylabel'] += ' density' + counts = counts / counts.sum() + if "ylabel" in axparams and "density" not in axparams["ylabel"]: + axparams["ylabel"] += " density" else: - axparams['ylabel'] = 'density' + axparams["ylabel"] = "density" else: - axparams.setdefault('ylabel', '# transcripts') + axparams.setdefault("ylabel", "# transcripts") if smooth: counts = counts.ewm(span=smooth).mean() for gn, gc in counts.items(): lines = ax.plot(x, gc / sz, label=gn, color=colors.get(gn, None), lw=lw, ls=ls) if fill: - ax.fill_between(x, 0, gc / sz, alpha=.5, color=lines[-1].get_color()) + ax.fill_between(x, 0, gc / sz, alpha=0.5, color=lines[-1].get_color()) # ax.plot(x, counts.divide(sz, axis=0)) ax.set(**axparams) if legend: @@ -294,8 +421,17 @@ def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=Tr return ax -def plot_saturation(isoseq=None, ax=None, cov_th=2, expr_th=[.5, 1, 2, 5, 10], x_range=(1e4, 1e7, 1e4), legend=True, label=True, **axparams): - '''Plots Negative Binomial model to analyze the saturation of LRTS data. +def plot_saturation( + isoseq=None, + ax=None, + cov_th=2, + expr_th=None, + x_range=(1e4, 1e7, 1e4), + legend=True, + label=True, + **axparams, +): + """Plots Negative Binomial model to analyze the saturation of LRTS data. Saturation (e.g. 
the probability to observe a transcript of interest in the sample) is dependent on the sequencing depth (number of reads), the concentration of the transcripts of interest in the sample (in TPM), @@ -309,22 +445,42 @@ def plot_saturation(isoseq=None, ax=None, cov_th=2, expr_th=[.5, 1, 2, 5, 10], x :param x_range: Specify the range of the x axis (e.g. the sequencing depth) :param legend: If set True, a legend is added to the plot. :param label: If set True, the sample names and sequencing depth from the isoseq parameter is printed in the plot. - :param \\**axparams: Additional keyword parameters are passed to ax.set().''' + :param \\**axparams: Additional keyword parameters are passed to ax.set(). + """ + if expr_th is None: + expr_th = [0.5, 1, 2, 5, 10] + if ax is None: _, ax = plt.subplots() k = np.arange(*x_range) - axparams.setdefault('title', 'Saturation Analysis') # [nr],{'fontsize':20}, loc='left', pad=10) - axparams.setdefault('ylabel', (f'probaility of sampling at least {cov_th} transcript{"s" if cov_th>1 else ""}')) - axparams.setdefault('ylim', (0, 1)) - axparams.setdefault('xlabel', 'number of reads [million]') - n_reads = isoseq.sample_table.set_index('name')['nonchimeric_reads'] if isoseq is not None else {} + axparams.setdefault( + "title", "Saturation Analysis" + ) # [nr],{'fontsize':20}, loc='left', pad=10) + axparams.setdefault( + "ylabel", + f"Probability of sampling at least {cov_th} transcript{'s' if cov_th > 1 else ''}", + ) + axparams.setdefault("ylim", (0, 1)) + axparams.setdefault("xlabel", "number of reads [million]") + n_reads = ( + isoseq.sample_table.set_index("name")["nonchimeric_reads"] + if isoseq is not None + else {} + ) for tpm_th in expr_th: - chance = nbinom.cdf(k - cov_th, n=cov_th, p=tpm_th * 1e-6) # 0 to k-cov_th failiors - ax.plot(k / 1e6, chance, label=f'{tpm_th} TPM') + chance = nbinom.cdf( + k - cov_th, n=cov_th, p=tpm_th * 1e-6 + ) # 0 to k-cov_th failiors + ax.plot(k / 1e6, chance, label=f"{tpm_th} TPM") for sample, cov 
in n_reads.items(): - ax.axvline(cov / 1e6, color='grey', ls='--') + ax.axvline(cov / 1e6, color="grey", ls="--") if label: - ax.text((cov + (k[-1] - k[0]) / 200) / 1e6, 0.1, f'{sample} ({cov/1e6:.2f} M)', rotation=-90) + ax.text( + (cov + (k[-1] - k[0]) / 200) / 1e6, + 0.1, + f"{sample} ({cov/1e6:.2f} M)", + rotation=-90, + ) ax.set(**axparams) if legend: @@ -332,8 +488,17 @@ def plot_saturation(isoseq=None, ax=None, cov_th=2, expr_th=[.5, 1, 2, 5, 10], x return ax -def plot_rarefaction(rarefaction, total=None, ax=None, legend=True, colors=None, lw=1, ls='solid', **axparams): - '''Plots the rarefaction curve. +def plot_rarefaction( + rarefaction, + total=None, + ax=None, + legend=True, + colors=None, + lw=1, + ls="solid", + **axparams, +): + """Plots the rarefaction curve. :param rarefaction: A DataFrame with the observed number of transcripts, as computed by Transcriptome.rarefaction(). :param total: A dictionary with the total number of reads per sample/sample group, as computed by Transcriptome.rarefaction(). @@ -342,65 +507,124 @@ def plot_rarefaction(rarefaction, total=None, ax=None, legend=True, colors=None :param colors: Provide a dictionary with label keys and color values. By default, colors are automatically assigned. :param lw: Specify witdh of the lines. See matplotlib Line2D for details. :param ls: Specify style of the lines. See matplotlib Line2D for details. 
- :param \\**axparams: Additional keyword parameters are passed to ax.set().''' + :param \\**axparams: Additional keyword parameters are passed to ax.set().""" if ax is None: _, ax = plt.subplots() if colors is None: colors = {} for sample in rarefaction.columns: - ax.plot([float(f) * total[sample] / 1e6 if total is not None else float(f)*100 for f in rarefaction.index], rarefaction[sample], - label=sample, ls=ls, lw=lw, color=colors.get(sample, None)) - - axparams.setdefault('title', 'Rarefaction Analysis') # [nr],{'fontsize':20}, loc='left', pad=10) - axparams.setdefault('ylabel', 'Number of discovered Transcripts') - axparams.setdefault('xlabel', 'Fraction of subsampled reads [%]' if total is None else 'Number of subsampled reads [million]') + ax.plot( + [ + float(f) * total[sample] / 1e6 if total is not None else float(f) * 100 + for f in rarefaction.index + ], + rarefaction[sample], + label=sample, + ls=ls, + lw=lw, + color=colors.get(sample, None), + ) + + axparams.setdefault( + "title", "Rarefaction Analysis" + ) # [nr],{'fontsize':20}, loc='left', pad=10) + axparams.setdefault("ylabel", "Number of discovered Transcripts") + axparams.setdefault( + "xlabel", + ( + "Fraction of subsampled reads [%]" + if total is None + else "Number of subsampled reads [million]" + ), + ) ax.set(**axparams) if legend: ax.legend() return ax -def plot_str_var_number(str_var_count, group_name:'str', n_multi=10, fig_size=(12, 4), fig_title=None, **axparams): - ''' +def plot_str_var_number( + str_var_count, + group_name: "str", + n_multi=10, + fig_size=(12, 4), + fig_title=None, + **axparams, +): + """ Generates a figure with three barplots, depicting the number of genes with a certain number of structural variations, regarding distinct TSSs, exon chains and PASs in a gene. :param str_var_count: The count number of three categories of a group of interest, generated by Transcriptome.str_var_calculation(count_number=True). 
:param group_name: The name of the group that will be used to search for corresponding columns in group_str_var. :param \\**axparams: Additional keyword parameters are passed to ax.set(), eg: xlabel='xxx'. - ''' + """ fig, axs = plt.subplots(1, 3, figsize=fig_size) group_tab = str_var_count.loc[:, str_var_count.columns.str.startswith(group_name)] - - for i, feature in enumerate(group_tab.columns.str.split('_').str[-1].unique()): - n_feature_tab = group_tab.filter(regex=feature).value_counts(dropna=True).to_frame().sort_index().reset_index() - n_feature_tab.columns = ['n_feature', 'n_gene'] - - n_feature_mask = pd.concat([n_feature_tab[n_feature_tab['n_feature'] < n_multi], - pd.DataFrame({'n_feature': n_multi, 'n_gene': n_feature_tab[n_feature_tab['n_feature'] >= n_multi]['n_gene'].sum()}, index=[0])]) - - axs[i].bar(n_feature_mask['n_feature'], n_feature_mask['n_gene']) - - y = max(n_feature_mask['n_gene'].iloc[1:]) - maxy = max(max(n_feature_mask['n_gene'])*1.1, y*2) + feature_list = group_tab.columns.str.split("_").str[-1].unique().tolist() + + # update group_tab to avoid cases where group_name is a prefix of another group name + group_tab = group_tab.loc[:, [f"{group_name}_{f}" for f in feature_list]] + + for i, feature in enumerate(feature_list): + n_feature_tab = ( + group_tab.filter(regex=feature) + .value_counts(dropna=True) + .to_frame() + .sort_index() + .reset_index() + ) + n_feature_tab.columns = ["n_feature", "n_gene"] + + n_feature_mask = pd.concat( + [ + n_feature_tab[n_feature_tab["n_feature"] < n_multi], + pd.DataFrame( + { + "n_feature": n_multi, + "n_gene": n_feature_tab[n_feature_tab["n_feature"] >= n_multi][ + "n_gene" + ].sum(), + }, + index=[0], + ), + ] + ) + + axs[i].bar(n_feature_mask["n_feature"], n_feature_mask["n_gene"]) + + y = max(n_feature_mask["n_gene"].iloc[1:]) + maxy = max(max(n_feature_mask["n_gene"]) * 1.1, y * 2) axs[i].set_ylim(0, maxy) x = (n_multi + 2) / 2 - pct_multi = 1 - n_feature_mask['n_gene'].iloc[0] / 
n_feature_mask['n_gene'].sum() - props = {'connectionstyle':'bar, fraction=0.15','arrowstyle':'-',\ - 'shrinkA':10, 'shrinkB':10, 'linewidth':2} - - axs[i].text(x, y + (maxy * 0.2), f'{pct_multi:.2%}', ha='center') - axs[i].annotate('', xy=(2, y), xytext=(n_multi, y), arrowprops=props) - axs[i].set_xticks(range(1, n_multi+1)) - axs[i].set_xticklabels([j+1 for j in range(n_multi-1)] + ['>=' + str(n_multi)], rotation=20) - - axs[i].set_title(f"# {'exon_chain' if feature == 'ec' else feature.upper()} / gene", fontsize=10) - - if 'ylabel' not in axparams: - axs[0].set_ylabel('number of genes') + pct_multi = ( + 1 - n_feature_mask["n_gene"].iloc[0] / n_feature_mask["n_gene"].sum() + ) + props = { + "connectionstyle": "bar, fraction=0.15", + "arrowstyle": "-", + "shrinkA": 10, + "shrinkB": 10, + "linewidth": 2, + } + + axs[i].text(x, y + (maxy * 0.2), f"{pct_multi:.2%}", ha="center") + axs[i].annotate("", xy=(2, y), xytext=(n_multi, y), arrowprops=props) + axs[i].set_xticks(range(1, n_multi + 1)) + axs[i].set_xticklabels( + [j + 1 for j in range(n_multi - 1)] + [">=" + str(n_multi)], rotation=20 + ) + + axs[i].set_title( + f"# {'exon_chain' if feature == 'ec' else feature.upper()} / gene", + fontsize=10, + ) + + if "ylabel" not in axparams: + axs[0].set_ylabel("number of genes") axs[i].set(**{k: v for k, v in axparams.items() if k in axs[i].properties()}) @@ -414,7 +638,7 @@ def plot_str_var_number(str_var_count, group_name:'str', n_multi=10, fig_size=(1 def triangle_plot(str_var_tab, ax=None, colors=None, tax_title=None): - ''' + """ Generate a triangle plot from str_var_tab. Each row would be a dot in the triangle plot. There can be multiple sets of three columns. Every set must be in the order of TSS, exon chain and PAS, and named as "group_feature", eg: wt_tss, wt_ec, wt_pas. @@ -424,25 +648,29 @@ def triangle_plot(str_var_tab, ax=None, colors=None, tax_title=None): If a string, all dots would be colored in the same color. 
If a list, there should be only one group of structural variation, and the length of the list should be equal to the number of rows in str_var_tab. If a dict, the keys should be the group names, consistent with the prefix of columns in str_var_tab, and the values should be the colors. - ''' + """ - coords = str_var_tab.filter(regex='_(tss|ec|pas)') - assert all(coords.columns.str.contains('_')), 'name the columns as "group_feature", eg: wt_tss, wt_ec, wt_pas' + coords = str_var_tab.filter(regex="_(tss|ec|pas)") + assert all( + coords.columns.str.contains("_") + ), 'name the columns as "group_feature", eg: wt_tss, wt_ec, wt_pas' - groups = coords.columns.str.split('_').str[:-1].str.join('_').unique() + groups = coords.columns.str.split("_").str[:-1].str.join("_").unique() if colors is None: - color_scheme = {k: 'orange' for k in groups} + color_scheme = {k: "orange" for k in groups} elif isinstance(colors, str): color_scheme = {k: colors for k in groups} elif isinstance(colors, list): - assert len(colors) == len(coords), 'the length of colors should be equal to the number of rows in str_var_tab' + assert len(colors) == len( + coords + ), "the length of colors should be equal to the number of rows in str_var_tab" color_scheme = {k: colors for k in groups} elif isinstance(colors, dict): - assert all(k in colors for k in groups), 'not all groups have a defined color' + assert all(k in colors for k in groups), "not all groups have a defined color" color_scheme = colors else: - raise ValueError('colors must be a string, list, or dict') + raise ValueError("colors must be a string, list, or dict") scale = 1 if ax: @@ -450,37 +678,51 @@ def triangle_plot(str_var_tab, ax=None, colors=None, tax_title=None): else: _, tax = ternary.figure(scale=scale) - for gn in groups: - vals = coords.loc[:, coords.columns.str.startswith(gn)] - tax.scatter(vals.to_numpy()[:, [2, 1, 0]], color=color_scheme[gn], alpha=.5, label=gn) - tax.boundary(linewidth=1.5) tax.gridlines(multiple=0.25, 
linewidth=0.5) - tax.left_axis_label("TSS", fontsize=12, offset=0.12, weight='bold') - tax.right_axis_label("splicing ratio", fontsize=12, offset=0.12, weight='bold') - tax.bottom_axis_label("PAS", fontsize=12, offset=0.04, weight='bold') + tax.left_axis_label("TSS", fontsize=12, offset=0.12, weight="bold") + tax.right_axis_label("splicing ratio", fontsize=12, offset=0.12, weight="bold") + tax.bottom_axis_label("PAS", fontsize=12, offset=0.04, weight="bold") if tax_title: tax.set_title(tax_title, fontsize=14, pad=30) - tax.ticks(axis='lbr', linewidth=1, multiple=0.25, offset=0.02, tick_formats="%.2f") + tax.ticks(axis="lbr", linewidth=1, multiple=0.25, offset=0.02, tick_formats="%.2f") # label different areas - tax.horizontal_line(0.5, linewidth=3, color='palevioletred', linestyle="-.") - tax.top_corner_label("splicing high", color='palevioletred', fontsize=12, offset=0.18, weight='bold') - tax.left_parallel_line(0.5, linewidth=3, color='olivedrab', linestyle="-.") - tax.right_corner_label("PAS high", position=(1.0, 0.05, 0), color='olivedrab', fontsize=12, weight='bold') - tax.right_parallel_line(0.5, linewidth=3, color='cornflowerblue', linestyle="-.") - tax.left_corner_label("TSS high", color='cornflowerblue', fontsize=12, offset=0.2, weight='bold') - tax.scatter([[1/3, 1/3, 1/3]], marker='*', color='saddlebrown', s=120) # simple + tax.horizontal_line(0.5, linewidth=3, color="palevioletred", linestyle="-.") + tax.top_corner_label( + "splicing high", color="palevioletred", fontsize=12, offset=0.18, weight="bold" + ) + tax.left_parallel_line(0.5, linewidth=3, color="olivedrab", linestyle="-.") + tax.right_corner_label( + "PAS high", + position=(1.0, 0.05, 0), + color="olivedrab", + fontsize=12, + weight="bold", + ) + tax.right_parallel_line(0.5, linewidth=3, color="cornflowerblue", linestyle="-.") + tax.left_corner_label( + "TSS high", color="cornflowerblue", fontsize=12, offset=0.2, weight="bold" + ) + tax.scatter( + [[1 / 3, 1 / 3, 1 / 3]], marker="*", 
color="saddlebrown", s=120 + ) # simple tax.set_background_color(color="whitesmoke", alpha=0.7) + for gn in groups: + vals = coords.loc[:, coords.columns.str.startswith(gn)] + tax.scatter( + vals.to_numpy()[:, [2, 1, 0]], color=color_scheme[gn], alpha=0.7, label=gn + ) + if isinstance(colors, dict): - tax.legend(title=None, fontsize=10, facecolor='white', frameon=True) + tax.legend(title=None, fontsize=10, facecolor="white", frameon=True) # remove default matplotlib axes tax.clear_matplotlib_ticks() - tax.get_axes().axis('off') + tax.get_axes().axis("off") - return tax \ No newline at end of file + return tax diff --git a/src/isotools/run_isotools.py b/src/isotools/run_isotools.py index a28759d..6f63708 100644 --- a/src/isotools/run_isotools.py +++ b/src/isotools/run_isotools.py @@ -1,6 +1,7 @@ import isotools import matplotlib.pyplot as plt import argparse + # import numpy as np import pandas as pd from isotools import Transcriptome @@ -10,39 +11,122 @@ import sys -logger = logging.getLogger('run_isotools') +logger = logging.getLogger("run_isotools") def argument_parser(): - parser = argparse.ArgumentParser(prog='isotools', description='process LRTS data with isotool') - parser.add_argument('--anno', metavar='', help='specify reference annotation') - parser.add_argument('--genome', metavar='', help='specify reference genome file') - parser.add_argument('--samples', metavar='', help='add samples from sample tsv') - parser.add_argument('--file_prefix', metavar='', default='./', help='Specify output path and prefix.') - parser.add_argument('--file_suffix', metavar='', help='Specify output sufix (not used for pickle).') - parser.add_argument('--short_read_samples', metavar='', help='Specify tsv with short read samples.') - parser.add_argument('--force_recreate', help='reimport transcriptomes from alignments, even in presence of pickle file.', action='store_true') - parser.add_argument('--no_pickle', help='Do not pickle the transcriptome for later use.', 
action='store_true') - parser.add_argument('--progress_bar', help='Show the progress of individual tasks.', action='store_true') - parser.add_argument("-l", "--log", dest="logLevel", default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', None], help="Set the logging level.") - parser.add_argument('--group_by', metavar='', - help='specify column used for grouping the samples. This applies to \ - --qc_plots, --altsplice_stats, --diff, --diff_plots and --altsplice_plots', - default='name') - parser.add_argument('--custom_filter_tag', metavar='', help='add custom filter tag', nargs='*') - parser.add_argument('--filter_query', metavar='<"expression">', default='FSM or not (INTERNAL_PRIMING or RTTS)', - help='filter the transcripts used in gtf and table output') - parser.add_argument('--qc_plots', help='make qc plots', action='store_true') - parser.add_argument('--altsplice_stats', help='alternative splicing barplots', action='store_true') - parser.add_argument('--transcript_table', help='make transcript_table', action='store_true') - parser.add_argument('--gtf_out', help='make filtered gtf', action='store_true') - parser.add_argument('--diff', metavar='', nargs='*', help='perform differential splicing analysis') - parser.add_argument('--diff_plots', metavar='', type=int, help='make sashimi plots for top differential genes') - parser.add_argument('--plot_type', metavar='', type=str, default='png', choices=['png', 'pdf', 'svg', 'eps', 'pgf', 'ps']) - parser.add_argument('--plot_dpi', metavar='', type=int, default=100, help='Specify resolution of plots') - parser.add_argument('--altsplice_plots', metavar='', type=int, - help='make sashimi plots for top covered alternative spliced genes for each category') + parser = argparse.ArgumentParser( + prog="isotools", description="process LRTS data with isotool" + ) + parser.add_argument( + "--anno", + metavar="", + help="specify reference annotation", + ) + parser.add_argument( + "--genome", metavar="", 
help="specify reference genome file" + ) + parser.add_argument( + "--samples", metavar="", help="add samples from sample tsv" + ) + parser.add_argument( + "--file_prefix", + metavar="", + default="./", + help="Specify output path and prefix.", + ) + parser.add_argument( + "--file_suffix", + metavar="", + help="Specify output sufix (not used for pickle).", + ) + parser.add_argument( + "--short_read_samples", + metavar="", + help="Specify tsv with short read samples.", + ) + parser.add_argument( + "--force_recreate", + help="reimport transcriptomes from alignments, even in presence of pickle file.", + action="store_true", + ) + parser.add_argument( + "--no_pickle", + help="Do not pickle the transcriptome for later use.", + action="store_true", + ) + parser.add_argument( + "--progress_bar", + help="Show the progress of individual tasks.", + action="store_true", + ) + parser.add_argument( + "-l", + "--log", + dest="logLevel", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", None], + help="Set the logging level.", + ) + parser.add_argument( + "--group_by", + metavar="", + help="specify column used for grouping the samples. 
This applies to \ + --qc_plots, --altsplice_stats, --diff, --diff_plots and --altsplice_plots", + default="name", + ) + parser.add_argument( + "--custom_filter_tag", + metavar='', + help="add custom filter tag", + nargs="*", + ) + parser.add_argument( + "--filter_query", + metavar='<"expression">', + default="FSM or not (INTERNAL_PRIMING or RTTS)", + help="filter the transcripts used in gtf and table output", + ) + parser.add_argument("--qc_plots", help="make qc plots", action="store_true") + parser.add_argument( + "--altsplice_stats", help="alternative splicing barplots", action="store_true" + ) + parser.add_argument( + "--transcript_table", help="make transcript_table", action="store_true" + ) + parser.add_argument("--gtf_out", help="make filtered gtf", action="store_true") + parser.add_argument( + "--diff", + metavar="", + nargs="*", + help="perform differential splicing analysis", + ) + parser.add_argument( + "--diff_plots", + metavar="", + type=int, + help="make sashimi plots for top differential genes", + ) + parser.add_argument( + "--plot_type", + metavar="", + type=str, + default="png", + choices=["png", "pdf", "svg", "eps", "pgf", "ps"], + ) + parser.add_argument( + "--plot_dpi", + metavar="", + type=int, + default=100, + help="Specify resolution of plots", + ) + parser.add_argument( + "--altsplice_plots", + metavar="", + type=int, + help="make sashimi plots for top covered alternative spliced genes for each category", + ) return parser @@ -50,17 +134,20 @@ def main(): parser = argument_parser() args = parser.parse_args() - plt.rcParams['savefig.dpi'] = args.plot_dpi + plt.rcParams["savefig.dpi"] = args.plot_dpi if args.logLevel: - logging.basicConfig(level=getattr(logging, args.logLevel), format='%(asctime)s %(levelname)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S') - logger.info('This is isotools version %s', isotools.__version__) - logger.debug('arguments: %s', args) + logging.basicConfig( + level=getattr(logging, args.logLevel), + format="%(asctime)s 
%(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + logger.info("This is isotools version %s", isotools.__version__) + logger.debug("arguments: %s", args) if args.file_suffix is None: - file_suffix = '' + file_suffix = "" else: - file_suffix = '_'+args.file_suffix + file_suffix = "_" + args.file_suffix try: isoseq = load_isoseq(args) @@ -70,54 +157,93 @@ def main(): exit(1) groups = isoseq.groups(args.group_by) - logger.debug('sample group definition: %s', groups) + logger.debug("sample group definition: %s", groups) if args.short_read_samples: illu_samples = pd.read_csv(args.short_read_samples) - isoseq.add_short_read_coverage(dict(zip(illu_samples['name'], illu_samples['file_name']))) + isoseq.add_short_read_coverage( + dict(zip(illu_samples["name"], illu_samples["file_name"])) + ) illu_groups = {} - if 'short_reads' in isoseq.infos: # todo: make this optional/parameter --dont_use_short_reads + if ( + "short_reads" in isoseq.infos + ): # todo: make this optional/parameter --dont_use_short_reads for grp, sample in groups.items(): - if sample in isoseq.infos['short_reads']['name']: - i = pd.Index(isoseq.infos['short_reads']['name']).get_loc(sample) + if sample in isoseq.infos["short_reads"]["name"]: + i = pd.Index(isoseq.infos["short_reads"]["name"]).get_loc(sample) illu_groups.setdefault(grp, []).append(i) - logger.debug('illumina sample group definition: %s\n%s', illu_groups, isoseq.infos["short_reads"]) + logger.debug( + "illumina sample group definition: %s\n%s", + illu_groups, + isoseq.infos["short_reads"], + ) if args.custom_filter_tag is not None: for f_def in args.custom_filter_tag: - tag, f_expr = f_def.split('=', 1) - if tag not in isoseq.filter['transcript']: - logger.info('adding new filter rule %s in transcript context', tag) - isoseq.add_filter(tag, f_expr, context='transcript', update=True) + tag, f_expr = f_def.split("=", 1) + if tag not in isoseq.filter["transcript"]: + logger.info("adding new filter rule %s in transcript context", 
tag) + isoseq.add_filter(tag, f_expr, context="transcript", update=True) if args.transcript_table: - trtab_fn = f'{args.file_prefix}_transcripts{file_suffix}.csv' - logger.info('writing transcript table to %s', trtab_fn) - df = isoseq.transcript_table(groups=groups, coverage=True, tpm=True, query=args.filter_query, progress_bar=args.progress_bar) + trtab_fn = f"{args.file_prefix}_transcripts{file_suffix}.csv" + logger.info("writing transcript table to %s", trtab_fn) + df = isoseq.transcript_table( + groups=groups, + coverage=True, + tpm=True, + query=args.filter_query, + progress_bar=args.progress_bar, + ) df.to_csv(trtab_fn) if args.gtf_out: - gtf_fn = f'{args.file_prefix}_transcripts{file_suffix}.gtf' - isoseq.write_gtf(gtf_fn, query=args.filter_query, progress_bar=args.progress_bar) + gtf_fn = f"{args.file_prefix}_transcripts{file_suffix}.gtf" + isoseq.write_gtf( + gtf_fn, query=args.filter_query, progress_bar=args.progress_bar + ) if args.qc_plots: - filter_plots(isoseq, groups, f'{args.file_prefix}_filter_stats{file_suffix}.{args.plot_type}', args.progress_bar) - transcript_plots(isoseq, groups, f'{args.file_prefix}_transcript_stats{file_suffix}.{args.plot_type}', args.progress_bar) + filter_plots( + isoseq, + groups, + f"{args.file_prefix}_filter_stats{file_suffix}.{args.plot_type}", + args.progress_bar, + ) + transcript_plots( + isoseq, + groups, + f"{args.file_prefix}_transcript_stats{file_suffix}.{args.plot_type}", + args.progress_bar, + ) if args.altsplice_stats: - altsplice_plots(isoseq, groups, f'{args.file_prefix}_altsplice{file_suffix}.{args.plot_type}', args.progress_bar) + altsplice_plots( + isoseq, + groups, + f"{args.file_prefix}_altsplice{file_suffix}.{args.plot_type}", + args.progress_bar, + ) if args.altsplice_plots: examples = altsplice_examples(isoseq, args.altsplice_plots) - plot_altsplice_examples(isoseq, groups, illu_groups, examples, args.file_prefix, file_suffix, args.plot_type) + plot_altsplice_examples( + isoseq, + groups, + 
illu_groups, + examples, + args.file_prefix, + file_suffix, + args.plot_type, + ) if args.diff is not None: test_differential(isoseq, groups, illu_groups, args, file_suffix) if not args.no_pickle: - logger.info('saving transcripts as pickle file') - isoseq.save(args.file_prefix+'_isotools.pkl') + logger.info("saving transcripts as pickle file") + isoseq.save(args.file_prefix + "_isotools.pkl") def load_isoseq(args): @@ -126,30 +252,34 @@ def load_isoseq(args): # if sample_tab is specified, genome must be specified if not args.force_recreate: try: - isoseq = Transcriptome.load(args.file_prefix+'_isotools.pkl') + isoseq = Transcriptome.load(args.file_prefix + "_isotools.pkl") except FileNotFoundError: if args.samples is None: - raise ValueError('No samples specified') + raise ValueError("No samples specified") if args.samples: if args.anno is None or args.genome is None: - raise ValueError('to add samples, genome and annotation must be provided.') + raise ValueError("to add samples, genome and annotation must be provided.") if isoseq is None: - isoseq = Transcriptome.from_reference(args.anno, progress_bar=args.progress_bar) + isoseq = Transcriptome.from_reference( + args.anno, progress_bar=args.progress_bar + ) isoseq.collapse_immune_genes() added = False - sample_tab = pd.read_csv(args.samples, sep='\t') - if 'sample_name' not in sample_tab.columns: + sample_tab = pd.read_csv(args.samples, sep="\t") + if "sample_name" not in sample_tab.columns: logger.debug(sample_tab.columns) raise ValueError('No "sample_name" column found in sample table') - if 'file_name' not in sample_tab.columns: + if "file_name" not in sample_tab.columns: raise ValueError('No "file_name" column found in sample table') for _, row in sample_tab.iterrows(): - if row['sample_name'] in isoseq.samples: - logger.info('skipping already present sample %s', row["sample_name"]) + if row["sample_name"] in isoseq.samples: + logger.info("skipping already present sample %s", row["sample_name"]) continue - 
sample_args = {k: v for k, v in row.items() if k != 'file_name'} - isoseq.add_sample_from_bam(fn=row.file_name, progress_bar=args.progress_bar, **sample_args) + sample_args = {k: v for k, v in row.items() if k != "file_name"} + isoseq.add_sample_from_bam( + fn=row.file_name, progress_bar=args.progress_bar, **sample_args + ) added = True if added: isoseq.add_qc_metrics(args.genome, progress_bar=args.progress_bar) @@ -159,81 +289,172 @@ def load_isoseq(args): def filter_plots(isoseq: Transcriptome, groups, filename, progress_bar): - logger.info('filter statistics plots') - f_stats = isoseq.filter_stats(groups=groups, weight_by_coverage=True, min_coverage=1, progress_bar=progress_bar) - plt.rcParams["figure.figsize"] = (15+5*len(groups), 7) + logger.info("filter statistics plots") + f_stats = isoseq.filter_stats( + groups=groups, + weight_by_coverage=True, + min_coverage=1, + progress_bar=progress_bar, + ) + plt.rcParams["figure.figsize"] = (15 + 5 * len(groups), 7) fig, ax = plt.subplots() isotools.plots.plot_bar(f_stats[0], ax=ax, **f_stats[1]) - fig.tight_layout(rect=[0, 0, 1, .95]) + fig.tight_layout(rect=[0, 0, 1, 0.95]) fig.savefig(filename) def transcript_plots(isoseq: Transcriptome, groups, filename, progress_bar): - logger.info('perparing summary of quality control metrics...') - logger.info('1) Number of RTTS, fragmentation and internal priming artefacts') - f_stats = isoseq.filter_stats(groups=groups, weight_by_coverage=True, min_coverage=1, - progress_bar=progress_bar, tags=('RTTS', 'FRAGMENT', 'INTERNAL_PRIMING')) + logger.info("perparing summary of quality control metrics...") + logger.info("1) Number of RTTS, fragmentation and internal priming artefacts") + f_stats = isoseq.filter_stats( + groups=groups, + weight_by_coverage=True, + min_coverage=1, + progress_bar=progress_bar, + tags=("RTTS", "FRAGMENT", "INTERNAL_PRIMING"), + ) tr_stats = [] - logger.info('2) Transcript length distribution') - 
tr_stats.append(isoseq.transcript_length_hist(groups=groups, add_reference=True, min_coverage=2, transcript_filter=dict(query='FSM', progress_bar=progress_bar))) - logger.info('3) Distribution of downstream A fraction in known genes') - tr_stats.append(isoseq.downstream_a_hist(groups=groups, transcript_filter=dict( - query='not (NOVEL_GENE or UNSPLICED)', progress_bar=progress_bar), ref_filter=dict(query='not UNSPLICED'))) - logger.info('4) Distribution of downstream A fraction in novel genes') - tr_stats.append(isoseq.downstream_a_hist(groups=groups, transcript_filter=dict(query='NOVEL_GENE and UNSPLICED', progress_bar=progress_bar))) - logger.info('5) Distribution of direct repeats') - tr_stats.append(isoseq.direct_repeat_hist(groups=groups, transcript_filter=dict(progress_bar=progress_bar))) - tr_stats.append((pd.concat([tr_stats[2][0].add_suffix(' novel unspliced'), tr_stats[1][0].add_suffix(' known multiexon')], axis=1), tr_stats[2][1])) + logger.info("2) Transcript length distribution") + tr_stats.append( + isoseq.transcript_length_hist( + groups=groups, + add_reference=True, + min_coverage=2, + transcript_filter=dict(query="FSM", progress_bar=progress_bar), + ) + ) + logger.info("3) Distribution of downstream A fraction in known genes") + tr_stats.append( + isoseq.downstream_a_hist( + groups=groups, + transcript_filter=dict( + query="not (NOVEL_GENE or UNSPLICED)", progress_bar=progress_bar + ), + ref_filter=dict(query="not UNSPLICED"), + ) + ) + logger.info("4) Distribution of downstream A fraction in novel genes") + tr_stats.append( + isoseq.downstream_a_hist( + groups=groups, + transcript_filter=dict( + query="NOVEL_GENE and UNSPLICED", progress_bar=progress_bar + ), + ) + ) + logger.info("5) Distribution of direct repeats") + tr_stats.append( + isoseq.direct_repeat_hist( + groups=groups, transcript_filter=dict(progress_bar=progress_bar) + ) + ) + tr_stats.append( + ( + pd.concat( + [ + tr_stats[2][0].add_suffix(" novel unspliced"), + 
tr_stats[1][0].add_suffix(" known multiexon"), + ], + axis=1, + ), + tr_stats[2][1], + ) + ) plt.rcParams["figure.figsize"] = (30, 25) - plt.rcParams.update({'font.size': 14}) + plt.rcParams.update({"font.size": 14}) fig, axs = plt.subplots(3, 2) # A) transcript length isotools.plots.plot_distr(tr_stats[0][0], smooth=3, ax=axs[0, 0], **tr_stats[0][1]) # D) frequency of artifacts - isotools.plots.plot_bar(f_stats[0], ax=axs[0, 1], drop_categories=['PASS'], **f_stats[1]) + isotools.plots.plot_bar( + f_stats[0], ax=axs[0, 1], drop_categories=["PASS"], **f_stats[1] + ) # B) internal priming - isotools.plots.plot_distr(tr_stats[4][0][[c for c in tr_stats[4][0].columns if 'novel' in c]], - smooth=3, ax=axs[1, 0], density=True, fill=True, **tr_stats[4][1]) - isotools.plots.plot_distr(tr_stats[4][0][[c for c in tr_stats[4][0].columns if 'known' in c]], - smooth=3, ax=axs[1, 1], density=True, fill=True, **tr_stats[4][1]) + isotools.plots.plot_distr( + tr_stats[4][0][[c for c in tr_stats[4][0].columns if "novel" in c]], + smooth=3, + ax=axs[1, 0], + density=True, + fill=True, + **tr_stats[4][1], + ) + isotools.plots.plot_distr( + tr_stats[4][0][[c for c in tr_stats[4][0].columns if "known" in c]], + smooth=3, + ax=axs[1, 1], + density=True, + fill=True, + **tr_stats[4][1], + ) # C) RTTS - isotools.plots.plot_distr(tr_stats[3][0][[c for c in tr_stats[3][0].columns if 'novel' in c]], ax=axs[2, 0], density=True, **tr_stats[3][1]) - isotools.plots.plot_distr(tr_stats[3][0][[c for c in tr_stats[3][0].columns if 'known' in c]], ax=axs[2, 1], density=True, **tr_stats[3][1]) - fig.tight_layout(rect=[0, 0, 1, .95]) + isotools.plots.plot_distr( + tr_stats[3][0][[c for c in tr_stats[3][0].columns if "novel" in c]], + ax=axs[2, 0], + density=True, + **tr_stats[3][1], + ) + isotools.plots.plot_distr( + tr_stats[3][0][[c for c in tr_stats[3][0].columns if "known" in c]], + ax=axs[2, 1], + density=True, + **tr_stats[3][1], + ) + fig.tight_layout(rect=[0, 0, 1, 0.95]) fig.savefig(filename) 
def altsplice_plots(isoseq: Transcriptome, groups, filename, progress_bar): - logger.info('preparing novel splicing statistics...') - altsplice = isoseq.altsplice_stats(groups=groups, transcript_filter=dict(query='not (RTTS or INTERNAL_PRIMING)', progress_bar=progress_bar)) + logger.info("preparing novel splicing statistics...") + altsplice = isoseq.altsplice_stats( + groups=groups, + transcript_filter=dict( + query="not (RTTS or INTERNAL_PRIMING)", progress_bar=progress_bar + ), + ) - plt.rcParams["figure.figsize"] = (15+5*len(groups), 10) + plt.rcParams["figure.figsize"] = (15 + 5 * len(groups), 10) fig, ax = plt.subplots() - isotools.plots.plot_bar(altsplice[0], ax=ax, drop_categories=['FSM'], **altsplice[1]) - fig.tight_layout(rect=[0, 0, 1, .95]) + isotools.plots.plot_bar( + altsplice[0], ax=ax, drop_categories=["FSM"], **altsplice[1] + ) + fig.tight_layout(rect=[0, 0, 1, 0.95]) fig.savefig(filename) -def altsplice_examples(isoseq: Transcriptome, n, query='not FSM'): # return the top n covered genes for each category +def altsplice_examples( + isoseq: Transcriptome, n, query="not FSM" +): # return the top n covered genes for each category examples = {} - for gene, transcript_ids, transcripts in isoseq.iter_transcripts(query=query, genewise=True): + for gene, transcript_ids, transcripts in isoseq.iter_transcripts( + query=query, genewise=True + ): total_cov = gene.coverage.sum() for transcript_id, transcript in zip(transcript_ids, transcripts): cov = gene.coverage[:, transcript_id].sum() - score = cov*cov/total_cov - for cat in transcript['annotation'][1]: - examples.setdefault(cat, []).append((score, gene.name, gene.id, transcript_id, cov, total_cov)) + score = cov * cov / total_cov + for cat in transcript["annotation"][1]: + examples.setdefault(cat, []).append( + (score, gene.name, gene.id, transcript_id, cov, total_cov) + ) examples = {k: sorted(v, key=lambda x: -x[0]) for k, v in examples.items()} return {k: v[:n] for k, v in examples.items()} -def 
plot_altsplice_examples(isoseq: Transcriptome, groups, illu_groups, examples, file_prefix, file_suffix, plot_type): - nplots = len(groups)+1 +def plot_altsplice_examples( + isoseq: Transcriptome, + groups, + illu_groups, + examples, + file_prefix, + file_suffix, + plot_type, +): + nplots = len(groups) + 1 # sample_idx = {r: i for i, r in enumerate(isoseq.infos['sample_table'].name)} if illu_groups: # illu_sample_idx = {r: i for i, r in enumerate(isoseq.infos['illumina_fn'])} @@ -241,46 +462,54 @@ def plot_altsplice_examples(isoseq: Transcriptome, groups, illu_groups, examples illu_groups = {gn: illu_groups[gn] for gn in groups if gn in illu_groups} nplots += len(illu_groups) # illumina is a dict with bam filenames - plt.rcParams["figure.figsize"] = (20, 5*nplots) + plt.rcParams["figure.figsize"] = (20, 5 * nplots) for cat, best_list in examples.items(): logger.debug(cat + str(best_list)) - for i, (score, gene_name, gene_id, transcript_id, cov, total_cov) in enumerate(best_list): + for i, (_score, gene_name, gene_id, transcript_id, cov, total_cov) in enumerate( + best_list + ): gene = isoseq[gene_id] try: info = gene.transcripts[transcript_id]["annotation"][1][cat] except TypeError: info = list() - logger.info(f'{i + 1}. best example for {cat}: {gene_name} {transcript_id} {info}, {cov} {total_cov} ({cov/total_cov:%})') + logger.info( + f"{i + 1}. 
best example for {cat}: {gene_name} {transcript_id} {info}, {cov} {total_cov} ({cov/total_cov:%})" + ) joi = [] # set joi if info: junctions = [] - if cat == 'exon skipping': - exons = gene.transcripts[transcript_id]['exons'] + if cat == "exon skipping": + exons = gene.transcripts[transcript_id]["exons"] for pos in info: idx = next(i for i, e in enumerate(exons) if e[0] > pos[0]) junctions.append((exons[idx - 1][1], exons[idx][0])) info = junctions - elif cat == 'novel exon': - exons = gene.transcripts[transcript_id]['exons'] + elif cat == "novel exon": + exons = gene.transcripts[transcript_id]["exons"] for i, exon in enumerate(exons[1:-1]): if exon in info: - junctions.extend([(exons[i][1], exon[0]), (exon[1], exons[i+2][0])]) - elif cat == 'novel junction': + junctions.extend( + [(exons[i][1], exon[0]), (exon[1], exons[i + 2][0])] + ) + elif cat == "novel junction": junctions = info for pos in junctions: try: if len(pos) == 2 and all(isinstance(x, int) for x in pos): - joi.append(tuple(pos)) # if this is a junction, it gets highlighed in the plot + joi.append( + tuple(pos) + ) # if this is a junction, it gets highlighed in the plot except TypeError: pass - print(f'junctions of interest: {joi}') + print(f"junctions of interest: {joi}") fig, axs = gene.sashimi_figure(samples=groups, junctions_of_interest=joi) fig.tight_layout() - stem = f'{file_prefix}_altsplice{file_suffix}_{cat.replace(" ","_").replace("/","_")}_{gene.name}' - fig.savefig(f'{stem}_sashimi.{plot_type}') + stem = f"{file_prefix}_altsplice{file_suffix}_{cat.replace(' ', '_').replace('/', '_')}_{gene.name}" + fig.savefig(f"{stem}_sashimi.{plot_type}") # zoom if info: for pos in info: @@ -288,69 +517,98 @@ def plot_altsplice_examples(isoseq: Transcriptome, groups, illu_groups, examples start, end = pos, pos elif len(pos) == 2 and all(isinstance(x, int) for x in pos): if pos[1] < pos[0]: - start, end = sorted([pos[0], pos[0]+pos[1]]) + start, end = sorted([pos[0], pos[0] + pos[1]]) else: start, end = 
pos else: continue for a in axs: a.set_xlim((start - 100, end + 100)) - axs[0].set_title(f'{gene.name} {gene.chrom}:{start}-{end} {cat} (cov={cov})') + axs[0].set_title( + f"{gene.name} {gene.chrom}:{start}-{end} {cat} (cov={cov})" + ) - plt.savefig(f'{stem}_zoom_{start}_{end}_sashimi.{plot_type}') + plt.savefig(f"{stem}_zoom_{start}_{end}_sashimi.{plot_type}") plt.close() -def plot_diffsplice(isoseq: Transcriptome, de_tab, groups, illu_gr, file_prefix, plot_type): +def plot_diffsplice( + isoseq: Transcriptome, de_tab, groups, illu_gr, file_prefix, plot_type +): - nplots = len(groups)+1 + nplots = len(groups) + 1 # sample_idx = {r: i for i, r in enumerate(isoseq.infos['sample_table'].name)} if illu_gr: # illu_sample_idx = {r: i for i, r in enumerate(isoseq.infos['illumina_fn'])} # todo: add illumina nplots += len(illu_gr) - plt.rcParams["figure.figsize"] = (7, 2*nplots) - for gene_id in de_tab['gene_id'].unique(): + plt.rcParams["figure.figsize"] = (7, 2 * nplots) + for gene_id in de_tab["gene_id"].unique(): gene = isoseq[gene_id] - logger.info(f'sashimi plot for differentially spliced gene {gene.name}') + logger.info(f"sashimi plot for differentially spliced gene {gene.name}") joi = [] - for _, regOI in de_tab.loc[de_tab['gene_id'] == gene_id].iterrows(): + for _, regOI in de_tab.loc[de_tab["gene_id"] == gene_id].iterrows(): # trA, trB = (list(map(int, regOI[i][1:-1].split(', '))) for i in ('trA', 'trB')) - transcriptA, transcriptB = list(regOI['trA']), list(regOI['trB']) + transcriptA, transcriptB = list(regOI["trA"]), list(regOI["trB"]) transcriptA.sort(key=lambda x: -gene.coverage[:, x].sum()) transcriptB.sort(key=lambda x: -gene.coverage[:, x].sum()) - joi.extend([(exon1[1], exon2[0]) for exon1, exon2 in pairwise(gene.transcripts[transcriptA[0]]['exons']) if exon1[1] >= regOI.start and exon2[0] <= regOI.end]) - joi.extend([(exon1[1], exon2[0]) for exon1, exon2 in pairwise(gene.transcripts[transcriptB[0]]['exons']) if exon1[1] >= regOI.start and exon2[0] <= 
regOI.end]) + joi.extend( + [ + (exon1[1], exon2[0]) + for exon1, exon2 in pairwise( + gene.transcripts[transcriptA[0]]["exons"] + ) + if exon1[1] >= regOI.start and exon2[0] <= regOI.end + ] + ) + joi.extend( + [ + (exon1[1], exon2[0]) + for exon1, exon2 in pairwise( + gene.transcripts[transcriptB[0]]["exons"] + ) + if exon1[1] >= regOI.start and exon2[0] <= regOI.end + ] + ) fig, axs = gene.sashimi_figure(samples=groups, junctions_of_interest=joi) fig.tight_layout() fig.savefig(f'{file_prefix}_{"_".join(groups)}_{gene.name}_sashimi.{plot_type}') # zoom - for i, row in de_tab.loc[de_tab.gene == gene.name].iterrows(): + for _, row in de_tab.loc[de_tab.gene == gene.name].iterrows(): if row.start > gene.start and row.end < gene.end: for a in axs: a.set_xlim((row.start - 1000, row.end + 1000)) - axs[0].set_title(f'{gene.name} {gene.chrom}:{row.start}-{row.end}') - fig.savefig(f'{file_prefix}_{"_".join(groups)}_{gene.name}_zoom_{row.start}_{row.end}_sashimi.{plot_type}') + axs[0].set_title(f"{gene.name} {gene.chrom}:{row.start}-{row.end}") + fig.savefig( + f'{file_prefix}_{"_".join(groups)}_{gene.name}_zoom_{row.start}_{row.end}_sashimi.{plot_type}' + ) plt.close(fig) def test_differential(isoseq: Transcriptome, groups, illu_groups, args, file_suffix): - file_prefix = f'{args.file_prefix}_diff{file_suffix}' + file_prefix = f"{args.file_prefix}_diff{file_suffix}" for diff_cmp in args.diff: - gr = diff_cmp.split('/') - logger.debug(f'processing {gr}') + gr = diff_cmp.split("/") + logger.debug(f"processing {gr}") if len(gr) != 2: - logger.error('--diff argument format error: provide two groups separated by "/" -- skipping') + logger.error( + '--diff argument format error: provide two groups separated by "/" -- skipping' + ) continue if not all(gn in groups for gn in gr): logger.error( - f'--diff argument format error: group names {[gn for gn in gr if gn not in groups]} not found in sample table -- skipping') + f"--diff argument format error: group names {[gn for gn in gr 
if gn not in groups]} not found in sample table -- skipping" + ) continue gr = {gn: groups[gn] for gn in gr} - res = isoseq.altsplice_test(gr, progress_bar=args.progress_bar).sort_values('pvalue') - sig = res.padj < .1 - logger.info(f'{sum(sig)} differential splice sites in {len(res.loc[sig,"gene"].unique())} genes for {" vs ".join(gr)}') + res = isoseq.altsplice_test(gr, progress_bar=args.progress_bar).sort_values( + "pvalue" + ) + sig = res.padj < 0.1 + logger.info( + f'{sum(sig)} differential splice sites in {len(res.loc[sig, "gene"].unique())} genes for {" vs ".join(gr)}' + ) res.to_csv(f'{file_prefix}_{"_".join(gr)}.csv', index=False) if args.diff_plots is not None: @@ -364,18 +622,38 @@ def test_differential(isoseq: Transcriptome, groups, illu_groups, args, file_suf cov = isoseq[gene].illumina_coverage for gi, grp_n in enumerate(gr): if grp_n not in illu_groups: - j_cov[gi] = 'NA' + j_cov[gi] = "NA" for sn in illu_groups[grp_n]: i = illu_groups[sn] for k, v in cov[i].junctions.items(): - j_cov[gi][k] = j_cov[gi].get(k, 0)+v + j_cov[gi][k] = j_cov[gi].get(k, 0) + v j_cov[gi].setdefault(ji, 0) - illu_cov.append((j_cov[0][ji], j_cov[1][ji], max(j_cov[0].values()), max(j_cov[1].values()))) - illu_cov = {k: v for k, v in zip(['illu_cov1', 'illu_cov2', 'illu_max1', 'illu_max2'], zip(*illu_cov))} + illu_cov.append( + ( + j_cov[0][ji], + j_cov[1][ji], + max(j_cov[0].values()), + max(j_cov[1].values()), + ) + ) + illu_cov = { + k: v + for k, v in zip( + ["illu_cov1", "illu_cov2", "illu_max1", "illu_max2"], + zip(*illu_cov), + ) + } sig_tab = sig_tab.assign(**illu_cov) sig_tab.to_csv(f'{file_prefix}_top_{"_".join(gr)}.csv') - plot_diffsplice(isoseq, res.head(args.diff_plots), gr, illu_groups, file_prefix, args.plot_type) + plot_diffsplice( + isoseq, + res.head(args.diff_plots), + gr, + illu_groups, + file_prefix, + args.plot_type, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/isotools/short_read.py b/src/isotools/short_read.py 
index d34b117..63d9f81 100644 --- a/src/isotools/short_read.py +++ b/src/isotools/short_read.py @@ -3,40 +3,52 @@ from ._utils import junctions_from_cigar import numpy as np import logging -logger = logging.getLogger('isotools') + +logger = logging.getLogger("isotools") class Coverage: - 'stores the illumina read coverage of a gene' + "stores the illumina read coverage of a gene" + # plan: make a binned version, or use run length encoding def __init__(self, cov, junctions, offset, chrom=None): self._cov = cov self._junctions = junctions - self.reg = None if cov is None else (chrom, offset, offset+len(cov)) + self.reg = None if cov is None else (chrom, offset, offset + len(cov)) self.bam_fn = None @classmethod def from_bam(cls, bam_fn, gene, load=False): - 'assign the bam file' + "assign the bam file" if load: - with AlignmentFile(bam_fn, 'rb') as align: + with AlignmentFile(bam_fn, "rb") as align: return cls.from_alignment(gene, align) else: # load on demand obj = cls.__new__(cls) obj._cov = None obj._junctions = None obj.bam_fn = bam_fn - start = min(gene.start, *[transcript['exons'][0][0] for transcript in gene.transcripts]) - end = max(gene.end, *[transcript['exons'][-1][1] for transcript in gene.transcripts]) + start = min( + gene.start, + *[transcript["exons"][0][0] for transcript in gene.transcripts], + ) + end = max( + gene.end, + *[transcript["exons"][-1][1] for transcript in gene.transcripts], + ) obj.reg = (gene.chrom, start, end) return obj @classmethod def from_alignment(cls, align_fh, gene): - 'load the coverage from bam file' - start = min(gene.start, *[transcript['exons'][0][0] for transcript in gene.transcripts]) - end = max(gene.end, *[transcript['exons'][-1][1] for transcript in gene.transcripts]) + "load the coverage from bam file" + start = min( + gene.start, *[transcript["exons"][0][0] for transcript in gene.transcripts] + ) + end = max( + gene.end, *[transcript["exons"][-1][1] for transcript in gene.transcripts] + ) cov, junctions = 
cls._import_coverage(align_fh, (gene.chrom, start, end)) obj = cls.__new__(cls) obj.__init__(cov, junctions, start) @@ -44,29 +56,31 @@ def from_alignment(cls, align_fh, gene): @classmethod # this is slow - called only if coverage is requested def _import_coverage(cls, align_fh, reg): - delta = np.zeros(reg[2]-reg[1]) + delta = np.zeros(reg[2] - reg[1]) junctions = {} for read in align_fh.fetch(*reg): exons = junctions_from_cigar(read.cigartuples, read.reference_start) # alternative: read.get_blocks() should be more efficient.. -todo: is it different? for i, exon in enumerate(exons): - s = max(reg[1], min(reg[2]-1, exon[0]))-reg[1] - e = max(reg[1], min(reg[2]-1, exon[1]))-reg[1] + s = max(reg[1], min(reg[2] - 1, exon[0])) - reg[1] + e = max(reg[1], min(reg[2] - 1, exon[1])) - reg[1] delta[s] += 1 delta[e] -= 1 if i > 0: - jpos = (exons[i-1][1], exon[0]) - if jpos[1]-jpos[0] < 1 or jpos[0] < reg[1] or jpos[1] > reg[2]: + jpos = (exons[i - 1][1], exon[0]) + if jpos[1] - jpos[0] < 1 or jpos[0] < reg[1] or jpos[1] > reg[2]: continue - junctions[jpos] = junctions.get(jpos, 0)+1 + junctions[jpos] = junctions.get(jpos, 0) + 1 cov = np.cumsum(delta) # todo: use rle instead? return cov, junctions def load(self): - 'load the coverage from bam file' - with AlignmentFile(self.bam_fn, 'rb') as align: - logger.debug(f'Illumina coverage of region {self.reg[0]}:{self.reg[1]}-{self.reg[2]} is loaded from {self.bam_fn}') # info or debug? + "load the coverage from bam file" + with AlignmentFile(self.bam_fn, "rb") as align: + logger.debug( + f"Illumina coverage of region {self.reg[0]}:{self.reg[1]}-{self.reg[2]} is loaded from {self.bam_fn}" + ) # info or debug? 
self._cov, self._junctions = type(self)._import_coverage(align, self.reg) @property @@ -83,11 +97,15 @@ def profile(self): def __getitem__(self, subscript): if isinstance(subscript, slice): - return self.profile[slice(None if subscript.start is None else subscript.start-self.reg[1], - None if subscript.stop is None else subscript.stop-self.reg[1], - subscript.step)] # does not get extended if outside range + return self.profile[ + slice( + None if subscript.start is None else subscript.start - self.reg[1], + None if subscript.stop is None else subscript.stop - self.reg[1], + subscript.step, + ) + ] # does not get extended if outside range elif subscript < self.reg[1] or subscript >= self.reg[2]: - logger.warning('requested coverage outside range') + logger.warning("requested coverage outside range") return None else: - return self.profile[subscript-self.reg[1]] + return self.profile[subscript - self.reg[1]] diff --git a/src/isotools/splice_graph.py b/src/isotools/splice_graph.py index ef917f0..58056b6 100755 --- a/src/isotools/splice_graph.py +++ b/src/isotools/splice_graph.py @@ -7,10 +7,10 @@ from .decorators import deprecated, experimental from typing import Generator, Literal, Optional, Union -logger = logging.getLogger('isotools') +logger = logging.getLogger("isotools") -class SegmentGraph(): +class SegmentGraph: '''Segment Graph Implementation Nodes in the Segment Graph represent disjoint exonic bins (aka segments) and have start (genomic 5'), end (genomic 3'), @@ -22,16 +22,18 @@ class SegmentGraph(): :type transcripts: list :param strand: the strand of the gene, either "+" or "-"''' - strand: Literal['+', '-'] + strand: Literal["+", "-"] _graph: list[SegGraphNode] _tss: list[int] "List of start-nodes for each transcript. TODO: Name is misleading. If the strand is '-', this is actually the PAS." _pas: list[int] "List of end-nodes for each transcript. TODO: Name is misleading. If the strand is '-', this is actually the TSS." 
- def __init__(self, transcript_exons: list[list[tuple[int, int]]], strand: Literal['+', '-']): + def __init__( + self, transcript_exons: list[list[tuple[int, int]]], strand: Literal["+", "-"] + ): self.strand = strand - assert strand in '+-', 'strand must be either "+" or "-"' + assert strand in "+-", 'strand must be either "+" or "-"' open_exons: dict[int, int] = dict() for exons in transcript_exons: for exon in exons: @@ -68,12 +70,12 @@ def __init__(self, transcript_exons: list[list[tuple[int, int]]], strand: Litera self._graph[start_idx[exon2[0]]].pre[i] = end_idx[exon[1]] def _restore(self, i: int) -> list: # mainly for testing - ''' Restore the i_{th} transcript from the Segment graph by traversing from 5' to 3' + """Restore the i_{th} transcript from the Segment graph by traversing from 5' to 3' :param i: The index of the transcript to restore :type i: int :return: A list of exon tuples representing the transcript - :rtype: list''' + :rtype: list""" idx = self._tss[i] exons = [[self._graph[idx].start, self._graph[idx].end]] while True: @@ -89,12 +91,12 @@ def _restore(self, i: int) -> list: # mainly for testing return exons def _restore_reverse(self, i: int) -> list: # mainly for testing - ''' Restore the ith transcript from the Segment graph by traversing from 3' to 5' + """Restore the ith transcript from the Segment graph by traversing from 3' to 5' :param i: The index of the transcript to restore :type i: int :return: A list of exon tuples representing the transcript - :rtype: list''' + :rtype: list""" idx = self._pas[i] exons = [[self._graph[idx].start, self._graph[idx].end]] while True: @@ -112,42 +114,63 @@ def _restore_reverse(self, i: int) -> list: # mainly for testing @deprecated def search_transcript2(self, exons: list[tuple[int, int]]): - '''Tests if a transcript (provided as list of exons) is contained in self and return the corresponding transcript indices. 
+ """Tests if a transcript (provided as list of exons) is contained in self and return the corresponding transcript indices. :param exons: A list of exon tuples representing the transcript :type exons: list :return: a list of supporting transcript indices - :rtype: list''' + :rtype: list""" # fst special case: exons extends segment graph if exons[0][1] <= self[0].start or exons[-1][0] >= self[-1].end: return [] # snd special case: single exon transcript: return all overlapping single exon transcripts form sg if len(exons) == 1: - return [transcript_id for transcript_id, (j1, j2) in enumerate(zip(self._tss, self._pas)) - if self._is_same_exon(transcript_id, j1, j2) and self[j1].start <= exons[0][1] and self[j2].end >= exons[0][0]] + return [ + transcript_id + for transcript_id, (j1, j2) in enumerate(zip(self._tss, self._pas)) + if self._is_same_exon(transcript_id, j1, j2) + and self[j1].start <= exons[0][1] + and self[j2].end >= exons[0][0] + ] # all junctions must be contained and no additional transcript = set(range(len(self._tss))) j = 0 for i, e in enumerate(exons[:-1]): - while j < len(self) and self[j].end < e[1]: # check exon (no junction allowed) - transcript -= set(transcript_id for transcript_id, j2 in self[j].suc.items() if self[j].end != self[j2].start) + while ( + j < len(self) and self[j].end < e[1] + ): # check exon (no junction allowed) + transcript -= set( + transcript_id + for transcript_id, j2 in self[j].suc.items() + if self[j].end != self[j2].start + ) j += 1 if self[j].end != e[1]: return [] # check junction (must be present) - transcript &= set(transcript_id for transcript_id, j2 in self[j].suc.items() if self[j2].start == exons[i + 1][0]) + transcript &= set( + transcript_id + for transcript_id, j2 in self[j].suc.items() + if self[j2].start == exons[i + 1][0] + ) j += 1 if len(transcript) == 0: return transcript while j < len(self): # check last exon (no junction allowed) - transcript -= set(transcript_id for transcript_id, j2 in 
self[j].suc.items() if self[j].end != self[j2].start) + transcript -= set( + transcript_id + for transcript_id, j2 in self[j].suc.items() + if self[j].end != self[j2].start + ) j += 1 return [transcript_id for transcript_id in transcript] - def search_transcript(self, exons: list[tuple[int, int]], complete=True, include_ends=False): - '''Tests if a transcript (provided as list of exons) is contained in sg and return the corresponding transcript indices. + def search_transcript( + self, exons: list[tuple[int, int]], complete=True, include_ends=False + ): + """Tests if a transcript (provided as list of exons) is contained in sg and return the corresponding transcript indices. Search the splice graph for transcripts that match the introns of the provided list of exons. @@ -158,7 +181,7 @@ def search_transcript(self, exons: list[tuple[int, int]], complete=True, include :param include_ends: If True, yield only splice graph transcripts that include the first and last exon. If False, also yield splice graph transcripts that extend first and/or last exon but match the intron chain. 
:return: a list of supporting transcript indices - :rtype: list''' + :rtype: list""" # fst special case: exons extends segment graph if include_ends: @@ -170,68 +193,133 @@ def search_transcript(self, exons: list[tuple[int, int]], complete=True, include # snd special case: single exon transcript: return all overlapping /including single exon transcripts form sg if len(exons) == 1 and complete: if include_ends: - return [transcript_id for transcript_id, (j1, j2) in enumerate(zip(self._tss, self._pas)) - if self._is_same_exon(transcript_id, j1, j2) and self[j1].start >= exons[0][0] and self[j2].end <= exons[0][1]] + return [ + transcript_id + for transcript_id, (j1, j2) in enumerate(zip(self._tss, self._pas)) + if self._is_same_exon(transcript_id, j1, j2) + and self[j1].start >= exons[0][0] + and self[j2].end <= exons[0][1] + ] else: - return [transcript_id for transcript_id, (j1, j2) in enumerate(zip(self._tss, self._pas)) - if self._is_same_exon(transcript_id, j1, j2) and self[j1].start <= exons[0][1] and self[j2].end >= exons[0][0]] + return [ + transcript_id + for transcript_id, (j1, j2) in enumerate(zip(self._tss, self._pas)) + if self._is_same_exon(transcript_id, j1, j2) + and self[j1].start <= exons[0][1] + and self[j2].end >= exons[0][0] + ] # j is index of last overlapping node - j = next((i for i, n in enumerate(self) if n.start >= exons[0][1]), len(self))-1 + j = ( + next((i for i, n in enumerate(self) if n.start >= exons[0][1]), len(self)) + - 1 + ) if self[j].end > exons[0][1] and len(exons) > 1: return [] if complete: # for include_ends we need to find first node of exon # j_first is index of first node from the first exon, - j_first = next((i for i in range(j, 0, -1) if self[i-1].end < self[i].start), 0) - transcript = [transcript_id for transcript_id, i in enumerate(self._tss) if (not include_ends or self[i].start <= exons[0][0]) - and j_first <= i <= j and self._is_same_exon(transcript_id, i, j)] + j_first = next( + (i for i in range(j, 0, -1) if 
self[i - 1].end < self[i].start), 0 + ) + transcript = [ + transcript_id + for transcript_id, i in enumerate(self._tss) + if (not include_ends or self[i].start <= exons[0][0]) + and j_first <= i <= j + and self._is_same_exon(transcript_id, i, j) + ] else: if include_ends: - j_first = next((i for i in range(j, 0, -1) if self[i].start <= exons[0][0]), 0) - transcript = [transcript_id for transcript_id in self[j].pre if self._is_same_exon(transcript_id, j_first, j)] + j_first = next( + (i for i in range(j, 0, -1) if self[i].start <= exons[0][0]), 0 + ) + transcript = [ + transcript_id + for transcript_id in self[j].pre + if self._is_same_exon(transcript_id, j_first, j) + ] if j_first == j: - transcript += [transcript_id for transcript_id, i in enumerate(self._tss) if i == j_first] + transcript += [ + transcript_id + for transcript_id, i in enumerate(self._tss) + if i == j_first + ] else: - transcript = list(self[j].pre)+[transcript_id for transcript_id, i in enumerate(self._tss) if i == j] + transcript = list(self[j].pre) + [ + transcript_id for transcript_id, i in enumerate(self._tss) if i == j + ] # all junctions must be contained and no additional for i, e in enumerate(exons[:-1]): - while j < len(self) and self[j].end < e[1]: # check exon (no junction allowed) - transcript = [transcript_id for transcript_id in transcript if transcript_id in self[j].suc and self[j].end == self[self[j].suc[transcript_id]].start] + while ( + j < len(self) and self[j].end < e[1] + ): # check exon (no junction allowed) + transcript = [ + transcript_id + for transcript_id in transcript + if transcript_id in self[j].suc + and self[j].end == self[self[j].suc[transcript_id]].start + ] j += 1 if self[j].end != e[1]: return [] # check junction (must be present) - transcript = [transcript_id for transcript_id in transcript if transcript_id in self[j].suc and self[self[j].suc[transcript_id]].start == exons[i+1][0]] + transcript = [ + transcript_id + for transcript_id in transcript + if 
transcript_id in self[j].suc + and self[self[j].suc[transcript_id]].start == exons[i + 1][0] + ] if not transcript: return [] j = self[j].suc[transcript[0]] if include_ends: while self[j].end < exons[-1][1]: - transcript = [transcript_id for transcript_id in transcript if transcript_id in self[j].suc and self[j].end == self[self[j].suc[transcript_id]].start] + transcript = [ + transcript_id + for transcript_id in transcript + if transcript_id in self[j].suc + and self[j].end == self[self[j].suc[transcript_id]].start + ] j += 1 if not complete: return transcript # ensure that all transcripts end (no junctions allowed) while j < len(self): # check last exon (no junction allowed) - transcript = [transcript_id for transcript_id in transcript if transcript_id not in self[j].suc or self[j].end == self[self[j].suc[transcript_id]].start] + transcript = [ + transcript_id + for transcript_id in transcript + if transcript_id not in self[j].suc + or self[j].end == self[self[j].suc[transcript_id]].start + ] j += 1 return transcript def _is_same_exon(self, transcript_number, j1, j2): - '''Tests if nodes j1 and j2 belong to same exon in transcript transcript_number.''' + """Tests if nodes j1 and j2 belong to same exon in transcript transcript_number.""" for j in range(j1, j2): - if transcript_number not in self[j].suc or self[j].suc[transcript_number] > j + 1 or self[j].end != self[j + 1].start: + if ( + transcript_number not in self[j].suc + or self[j].suc[transcript_number] > j + 1 + or self[j].end != self[j + 1].start + ): return False return True def _count_introns(self, transcript_number, j1, j2): - '''Counts the number of junctions between j1 and j2.''' - logger.debug('counting introns of transcript %i between nodes %i and %i', transcript_number, j1, j2) + """Counts the number of junctions between j1 and j2.""" + logger.debug( + "counting introns of transcript %i between nodes %i and %i", + transcript_number, + j1, + j2, + ) delta = 0 if j1 == j2: return 0 - assert 
transcript_number in self[j1].suc, f'transcript {transcript_number} does not contain node {j1}' + assert ( + transcript_number in self[j1].suc + ), f"transcript {transcript_number} does not contain node {j1}" while j1 < j2: j_next = self[j1].suc[transcript_number] if j_next > j1 + 1 or self[j1].end != self[j1 + 1].start: @@ -240,33 +328,54 @@ def _count_introns(self, transcript_number, j1, j2): return delta def get_node_matrix(self) -> np.array: - '''Gets the node matrix representation of the segment graph.''' - return np.array([[tss == j or transcript_id in n.pre for j, n in enumerate(self)] for transcript_id, tss in enumerate(self._tss)]) + """Gets the node matrix representation of the segment graph.""" + return np.array( + [ + [tss == j or transcript_id in n.pre for j, n in enumerate(self)] + for transcript_id, tss in enumerate(self._tss) + ] + ) def find_fragments(self): - '''Finds all fragments (e.g. transcript contained in other transcripts) in the segment graph.''' + """Finds all fragments (e.g. 
transcript contained in other transcripts) in the segment graph.""" truncated: set[int] = set() contains: dict[int, set[int]] = {} nodes = self.get_node_matrix() for transcript_id, (tss, pas) in enumerate(zip(self._tss, self._pas)): if transcript_id in truncated: continue - contains[transcript_id] = {transcript_id2 for transcript_id2, (tss2, pas2) in enumerate(zip(self._tss, self._pas)) - if transcript_id2 != transcript_id and tss2 >= tss and pas2 <= pas and - all(nodes[transcript_id2, tss2:pas2 + 1] == nodes[transcript_id, tss2:pas2 + 1])} + contains[transcript_id] = { + transcript_id2 + for transcript_id2, (tss2, pas2) in enumerate(zip(self._tss, self._pas)) + if transcript_id2 != transcript_id + and tss2 >= tss + and pas2 <= pas + and all( + nodes[transcript_id2, tss2 : pas2 + 1] + == nodes[transcript_id, tss2 : pas2 + 1] + ) + } truncated.update(contains[transcript_id]) # those are not checked fragments = {} for big, smallL in contains.items(): if big not in truncated: for transcript_id in smallL: - delta1 = self._count_introns(big, self._tss[big], self._tss[transcript_id]) - delta2 = self._count_introns(big, self._pas[transcript_id], self._pas[big]) - fragments.setdefault(transcript_id, []).append((big, delta1, delta2) if self.strand == '+' else (big, delta2, delta1)) + delta1 = self._count_introns( + big, self._tss[big], self._tss[transcript_id] + ) + delta2 = self._count_introns( + big, self._pas[transcript_id], self._pas[big] + ) + fragments.setdefault(transcript_id, []).append( + (big, delta1, delta2) + if self.strand == "+" + else (big, delta2, delta1) + ) return fragments def get_alternative_splicing(self, exons: list[tuple[int, int]], alternative=None): - '''Compares exons to segment graph and returns list of novel splicing events. + """Compares exons to segment graph and returns list of novel splicing events. This function computes the novelty class of the provided transcript compared to (reference annotation) transcripts from the segment graph. 
It returns the "squanti category" (0=FSM,1=ISM,2=NIC,3=NNC,4=Novel gene) and the subcategory. @@ -276,7 +385,7 @@ def get_alternative_splicing(self, exons: list[tuple[int, int]], alternative=Non :param alternative: list of splice site indices that match other genes :return: pair with the squanti category number and the subcategories as list of novel splicing events that produce the provided transcript from the transcripts in splice graph - :rtype: tuple''' + :rtype: tuple""" # returns a tuple # the sqanti category: 0=FSM,1=ISM,2=NIC,3=NNC,4=Novel gene @@ -286,42 +395,63 @@ def get_alternative_splicing(self, exons: list[tuple[int, int]], alternative=Non if alternative is not None and len(alternative) > 0: category = 4 fusion_exons = {int((i + 1) / 2) for j in alternative for i in j[1]} - altsplice = {'readthrough fusion': alternative} # other novel events are only found in the primary reference transcript + altsplice = { + "readthrough fusion": alternative + } # other novel events are only found in the primary reference transcript else: transcript = self.search_transcript(exons) if transcript: - return 0, {'FSM': transcript} + return 0, {"FSM": transcript} category = 1 altsplice = {} fusion_exons = set() - is_reverse = self.strand == '-' + is_reverse = self.strand == "-" j1 = next((j for j, n in enumerate(self) if n.end > exons[0][0])) # j1: index of first segment ending after exon start (i.e. first overlapping segment) - j2 = next((j - 1 for j in range(j1, len(self)) if self[j].start >= exons[0][1]), len(self) - 1) + j2 = next( + (j - 1 for j in range(j1, len(self)) if self[j].start >= exons[0][1]), + len(self) - 1, + ) # j2: index of last segment starting before exon end (i.e. last overlapping segment) # check truncation at begining (e.g. 
low position) - if (len(exons) > 1 and # no mono exon - not any(j in self._tss for j in range(j1, j2 + 1)) and # no tss/pas within exon - self[j1].start <= exons[0][0]): # start of first exon is exonic in ref - j0 = max(self._tss[transcript_id] for transcript_id in self[j1].pre) # j0 is the closest start node - if any(self[j].end < self[j + 1].start for j in range(j0, j1)): # assure there is an intron between closest tss/pas and exon - end = '5' if is_reverse else '3' - altsplice.setdefault(f'{end}\' fragment', []).append([self[j0].start, exons[0][0]]) # at start (lower position) + if ( + len(exons) > 1 # no mono exon + and not any( + j in self._tss for j in range(j1, j2 + 1) + ) # no tss/pas within exon + and self[j1].start <= exons[0][0] + ): # start of first exon is exonic in ref + j0 = max( + self._tss[transcript_id] for transcript_id in self[j1].pre + ) # j0 is the closest start node + if any( + self[j].end < self[j + 1].start for j in range(j0, j1) + ): # assure there is an intron between closest tss/pas and exon + end = "5" if is_reverse else "3" + altsplice.setdefault(f"{end}' fragment", []).append( + [self[j0].start, exons[0][0]] + ) # at start (lower position) for i, ex1 in enumerate(exons): ex2 = None if i + 1 == len(exons) else exons[i + 1] - if i not in fusion_exons: # exon belongs to other gene (read through fusion) + if ( + i not in fusion_exons + ): # exon belongs to other gene (read through fusion) # finds intron retention (NIC), novel exons, novel splice sites, novel pas/tss (NNC) - exon_altsplice, exon_cat = self._check_exon(j1, j2, i == 0, is_reverse, ex1, ex2) + exon_altsplice, exon_cat = self._check_exon( + j1, j2, i == 0, is_reverse, ex1, ex2 + ) category = max(exon_cat, category) for k, v in exon_altsplice.items(): altsplice.setdefault(k, []).extend(v) # find j2: index of last segment starting befor exon2 end (i.e. 
last overlapping segment) if ex2 is not None: if j2 + 1 < len(self): - j1, j2, junction_altsplice = self._check_junction(j1, j2, ex1, ex2) # finds exon skipping and novel junction (NIC) + j1, j2, junction_altsplice = self._check_junction( + j1, j2, ex1, ex2 + ) # finds exon skipping and novel junction (NIC) if junction_altsplice and i + 1 not in fusion_exons: category = max(2, category) for k, v in junction_altsplice.items(): @@ -330,106 +460,185 @@ def get_alternative_splicing(self, exons: list[tuple[int, int]], alternative=Non j1 = len(self) # check truncation at end (e.g. high position) - if (len(exons) > 1 and - j2 >= j1 and - not any(j in self._pas for j in range(j1, j2 + 1)) and # no tss/pas within exon - self[j2].end >= exons[-1][1]): # end of last exon is exonic in ref + if ( + len(exons) > 1 + and j2 >= j1 + and not any( + j in self._pas for j in range(j1, j2 + 1) + ) # no tss/pas within exon + and self[j2].end >= exons[-1][1] + ): # end of last exon is exonic in ref try: - j3 = min(self._pas[transcript_id] for transcript_id in self[j2].suc) # j3 is the next end node (pas/tss on fwd/rev) + j3 = min( + self._pas[transcript_id] for transcript_id in self[j2].suc + ) # j3 is the next end node (pas/tss on fwd/rev) except ValueError: - logger.error('\n'.join([str(exons), str(self._pas), str((j1, j2)), str([(j, n) for j, n in enumerate(self)])])) + logger.error( + "\n".join( + [ + str(exons), + str(self._pas), + str((j1, j2)), + str([(j, n) for j, n in enumerate(self)]), + ] + ) + ) raise - if any(self[j].end < self[j + 1].start for j in range(j2, j3)): # assure there is an intron between closest tss/pas and exon - end = '3' if is_reverse else '5' - altsplice.setdefault(f'{end}\' fragment', []).append([exons[-1][1], self[j3].end]) + if any( + self[j].end < self[j + 1].start for j in range(j2, j3) + ): # assure there is an intron between closest tss/pas and exon + end = "3" if is_reverse else "5" + altsplice.setdefault(f"{end}' fragment", []).append( + 
[exons[-1][1], self[j3].end] + ) if not altsplice: # all junctions are contained but not all in one transcript - altsplice = {'novel combination': []} + altsplice = {"novel combination": []} category = 2 return category, altsplice - def _check_exon(self, j1, j2, is_first, is_reverse, exon: tuple[int, int], exon2=None): - '''checks whether exon is supported by splice graph between nodes j1 and j2 + def _check_exon( + self, j1, j2, is_first, is_reverse, exon: tuple[int, int], exon2=None + ): + """checks whether exon is supported by splice graph between nodes j1 and j2 :param j1: index of first segment ending after exon start (i.e. first overlapping segment) - :param j2: index of last segment starting before exon end (i.e. last overlapping segment)''' - - logger.debug('exon %s between sg node %s and %s/%s (first=%s,rev=%s,e2=%s)', exon, j1, j2, len(self), is_first, is_reverse, exon2) + :param j2: index of last segment starting before exon end (i.e. last overlapping segment) + """ + + logger.debug( + "exon %s between sg node %s and %s/%s (first=%s,rev=%s,e2=%s)", + exon, + j1, + j2, + len(self), + is_first, + is_reverse, + exon2, + ) is_last = exon2 is None altsplice = {} category = 0 - if j1 > j2: # exon is not contained at all -> novel exon (or TSS/PAS if first/last) + if ( + j1 > j2 + ): # exon is not contained at all -> novel exon (or TSS/PAS if first/last) category = 3 if is_first or is_last: - altsplice = {'novel intronic PAS' if is_first == is_reverse else 'novel intronic TSS': [exon]} + altsplice = { + ( + "novel intronic PAS" + if is_first == is_reverse + else "novel intronic TSS" + ): [exon] + } else: - altsplice = {'novel exon': [exon]} + altsplice = {"novel exon": [exon]} j2 = j1 - elif (is_first and is_last): # mono-exon (should not overlap a reference monoexon transcript, this is caught earlier) - altsplice['mono-exon'] = [] + elif ( + is_first and is_last + ): # mono-exon (should not overlap a reference monoexon transcript, this is caught earlier) + 
altsplice["mono-exon"] = [] category = 1 else: # check splice sites if self[j1][0] != exon[0]: # first splice site missmatch if not is_first: # pos="intronic" if self[j1][0]>e[0] else "exonic" - kind = '5' if is_reverse else '3' - dist = min((self[j][0] - exon[0] for j in range(j1, j2 + 1)), key=abs) # the distance to next junction + kind = "5" if is_reverse else "3" + dist = min( + (self[j][0] - exon[0] for j in range(j1, j2 + 1)), key=abs + ) # the distance to next junction altsplice[f"novel {kind}' splice site"] = [(exon[0], dist)] category = 3 - elif self[j1][0] > exon[0] and not any(j in self._tss for j in range(j1, j2 + 1)): # exon start is intronic in ref - site = 'PAS' if is_reverse else 'TSS' - altsplice.setdefault(f'novel exonic {site}', []).append((exon[0], self[j1][0])) + elif self[j1][0] > exon[0] and not any( + j in self._tss for j in range(j1, j2 + 1) + ): # exon start is intronic in ref + site = "PAS" if is_reverse else "TSS" + altsplice.setdefault(f"novel exonic {site}", []).append( + (exon[0], self[j1][0]) + ) category = max(1, category) if self[j2][1] != exon[1]: # second splice site missmatch if not is_last: # pos="intronic" if self[j2][1] 0 for ji in range(j1, j2)): + if j1 < j2 and any( + self[ji + 1].start - self[ji].end > 0 for ji in range(j1, j2) + ): gaps = [ji for ji in range(j1, j2) if self[ji + 1].start - self[ji].end > 0] - if (gaps - and not (is_first and any(j in self._tss for j in range(gaps[-1] + 1, j2))) - and not (is_last and any(j in self._pas for j in range(j1, gaps[0] + 1)))): + if ( + gaps + and not ( + is_first and any(j in self._tss for j in range(gaps[-1] + 1, j2)) + ) + and not ( + is_last and any(j in self._pas for j in range(j1, gaps[0] + 1)) + ) + ): ret_introns = [] troi = set(self[j1].suc.keys()).intersection(self[j2].pre.keys()) if troi: j = j1 while j < j2: - nextj = min(js for transcript_id, js in self[j].suc.items() if transcript_id in troi) - if self[nextj].start - self[j].end > 0 and any(self[ji + 1].start - 
self[ji].end > 0 for ji in range(j, nextj)): + nextj = min( + js + for transcript_id, js in self[j].suc.items() + if transcript_id in troi + ) + if self[nextj].start - self[j].end > 0 and any( + self[ji + 1].start - self[ji].end > 0 + for ji in range(j, nextj) + ): ret_introns.append((self[j].end, self[nextj].start)) j = nextj if ret_introns: - altsplice['intron retention'] = ret_introns + altsplice["intron retention"] = ret_introns category = max(2, category) - logger.debug('check exon %s resulted in %s', exon, altsplice) + logger.debug("check exon %s resulted in %s", exon, altsplice) return altsplice, category def _check_junction(self, j1, j2, e, e2): - ''' check a junction in the segment graph + """check a junction in the segment graph * check presence e1-e2 junction in ref (-> if not exon skipping or novel junction) * presence is defined a direct junction from an ref exon (e.g. from self) overlapping e1 to an ref exon overlapping e2 * AND find j3 and j4: first node overlapping e2 and last node overlapping e2 * more specifically: * j3: first node ending after e2 start, or len(self) - * j4: last node starting before e2 end (assuming there is such a node)''' + * j4: last node starting before e2 end (assuming there is such a node)""" altsplice = {} - j3 = next((j for j in range(j2 + 1, len(self)) if self[j][1] > e2[0]), len(self)) - j4 = next((j - 1 for j in range(j3, len(self)) if self[j].start >= e2[1]), len(self) - 1) + j3 = next( + (j for j in range(j2 + 1, len(self)) if self[j][1] > e2[0]), len(self) + ) + j4 = next( + (j - 1 for j in range(j3, len(self)) if self[j].start >= e2[1]), + len(self) - 1, + ) if j3 == len(self) or self[j3].start > e2[1]: return j3, j4, altsplice # no overlap with e2 - if e[1] == self[j2].end and e2[0] == self[j3].start and j3 in self[j2].suc.values(): + if ( + e[1] == self[j2].end + and e2[0] == self[j3].start + and j3 in self[j2].suc.values() + ): return j3, j4, altsplice # found direct junction # find skipped exons within e1-e2 
intron exon_skipping = set() @@ -463,16 +672,20 @@ def _check_junction(self, j1, j2, e, e2): exons.append([e_start, e_end]) e_start = 0 if len(exons) > 1: - altsplice.setdefault('exon skipping', []).extend(exons[1:]) - elif e[1] == self[j2].end and e2[0] == self[j3].start: # e1-e2 path is not present, but splice sites are - altsplice.setdefault('novel junction', []).append([e[1], e2[0]]) # for example mutually exclusive exons spliced togeter + altsplice.setdefault("exon skipping", []).extend(exons[1:]) + elif ( + e[1] == self[j2].end and e2[0] == self[j3].start + ): # e1-e2 path is not present, but splice sites are + altsplice.setdefault("novel junction", []).append( + [e[1], e2[0]] + ) # for example mutually exclusive exons spliced togeter - logger.debug('check junction %s - %s resulted in %s', e[0], e[1], altsplice) + logger.debug("check junction %s - %s resulted in %s", e[0], e[1], altsplice) return j3, j4, altsplice def fuzzy_junction(self, exons: list[tuple[int, int]], size: int): - '''Looks for "fuzzy junctions" in the provided transcript. + """Looks for "fuzzy junctions" in the provided transcript. For each intron from "exons", look for introns in the splice graph shifted by less than "size". These shifts may be produced by ambigious alignments. 
@@ -482,7 +695,7 @@ def fuzzy_junction(self, exons: list[tuple[int, int]], size: int): :param size: The maximum size of the fuzzy junction :type size: int :return: a dict with the intron number as key and the shift as value (assuming size is smaller than introns) - :rtype: dict''' + :rtype: dict""" fuzzy = {} if size < 1: # no need to check return fuzzy @@ -497,24 +710,31 @@ def fuzzy_junction(self, exons: list[tuple[int, int]], size: int): except (StopIteration, IndexError): # transcript end - we are done break shift = [] - while j1 < len(self) and self[j1].end - exon1[1] <= min(size, exon1[1] - exon1[0]): # in case there are several nodes starting in the range around e1 + while j1 < len(self) and self[j1].end - exon1[1] <= min( + size, exon1[1] - exon1[0] + ): # in case there are several nodes starting in the range around e1 shift_e1 = self[j1].end - exon1[1] # print(f'{i} {e1[1]}-{e2[0]} {shift_e1}') if shift_e1 == 0: # no shift required at this intron break - if any(self[j2].start - exon2[0] == shift_e1 for j2 in set(self[j1].suc.values())): + if any( + self[j2].start - exon2[0] == shift_e1 + for j2 in set(self[j1].suc.values()) + ): shift.append(shift_e1) j1 += 1 else: # junction not found in sg if shift: # but shifted juction is present - fuzzy[i] = sorted(shift, key=abs)[0] # if there are several possible shifts, provide the smallest + fuzzy[i] = sorted(shift, key=abs)[ + 0 + ] # if there are several possible shifts, provide the smallest return fuzzy def find_splice_sites(self, splice_junctions: list[tuple[int, int]]): - '''Checks whether the splice sites of a new transcript are present in the segment graph. + """Checks whether the splice sites of a new transcript are present in the segment graph. 
:param splice_junctions: A list of 2-tuples with the splice site positions - :return: boolean array indicating whether the splice site is contained or not''' + :return: boolean array indicating whether the splice site is contained or not""" sites = np.zeros(len(splice_junctions) * 2, dtype=bool) splice_junction_starts = {} @@ -545,11 +765,12 @@ def find_splice_sites(self, splice_junctions: list[tuple[int, int]]): return sites def get_overlap(self, exons): - '''Compute the exonic overlap of a new transcript with the segment graph. + """Compute the exonic overlap of a new transcript with the segment graph. :param exons: A list of exon tuples representing the transcript :type exons: list - :return: a tuple: the overlap with the gene, and a list of the overlaps with the transcripts''' + :return: a tuple: the overlap with the gene, and a list of the overlaps with the transcripts + """ ol = 0 j = 0 @@ -562,12 +783,12 @@ def get_overlap(self, exons): while self[j].start < e[1]: i_end = min(e[1], self[j].end) i_start = max(e[0], self[j].start) - ol += (i_end - i_start) + ol += i_end - i_start for transcript_id in self[j].suc.keys(): - transcript_overlap[transcript_id] += (i_end - i_start) + transcript_overlap[transcript_id] += i_end - i_start for transcript_id, pas in enumerate(self._pas): if pas == j: - transcript_overlap[transcript_id] += (i_end - i_start) + transcript_overlap[transcript_id] += i_end - i_start if self[j].end > e[1]: break j += 1 @@ -577,16 +798,19 @@ def get_overlap(self, exons): return ol, transcript_overlap def get_intron_support_matrix(self, exons): - '''Check the intron support for the provided transcript w.r.t. transcripts from self. + """Check the intron support for the provided transcript w.r.t. transcripts from self. This is supposed to be helpful for the analysis of novel combinations of known splice sites. :param exons: A list of exon positions defining the transcript to check. 
:return: A boolean array of shape (n_transcripts in self)x(len(exons)-1). - An entry is True iff the intron from "exons" is present in the respective transcript of self.''' + An entry is True iff the intron from "exons" is present in the respective transcript of self. + """ node_iter = iter(self) - ism = np.zeros((len(self._tss), len(exons) - 1), bool) # the intron support matrix + ism = np.zeros( + (len(self._tss), len(exons) - 1), bool + ) # the intron support matrix for intron_nr, (e1, e2) in enumerate(pairwise(exons)): try: node = next(n for n in node_iter if n.end >= e1[1]) @@ -599,7 +823,7 @@ def get_intron_support_matrix(self, exons): return ism def get_exon_support_matrix(self, exons): - '''Check the exon support for the provided transcript w.r.t. transcripts from self. + """Check the exon support for the provided transcript w.r.t. transcripts from self. This is supposed to be helpful for the analysis of novel combinations of known splice sites. @@ -607,20 +831,33 @@ def get_exon_support_matrix(self, exons): :param exons: A list of exon positions defining the transcript to check. :return: A boolean array of shape (n_transcripts in self)x(len(exons)-1). An entry is True iff the exon from "exons" is fully covered in the respective transcript of self. 
- First and last exon are checked to overlap the first and last exon of the ref transcript but do not need to be fully covered''' + First and last exon are checked to overlap the first and last exon of the ref transcript but do not need to be fully covered + """ esm = np.zeros((len(self._tss), len(exons)), bool) # the intron support matrix - for transcript_number, tss in enumerate(self._tss): # check overlap of first exon + for transcript_number, tss in enumerate( + self._tss + ): # check overlap of first exon for j in range(tss, len(self)): if has_overlap(self[j], exons[0]): esm[transcript_number, 0] = True - elif self[j].suc.get(transcript_number, None) == j + 1 and j - 1 < len(self) and self[j].end == self[j + 1].start: + elif ( + self[j].suc.get(transcript_number, None) == j + 1 + and j - 1 < len(self) + and self[j].end == self[j + 1].start + ): continue break - for transcript_number, pas in enumerate(self._pas): # check overlap of last exon + for transcript_number, pas in enumerate( + self._pas + ): # check overlap of last exon for j in range(pas, -1, -1): if has_overlap(self[j], exons[-1]): esm[transcript_number, -1] = True - elif self[j].pre.get(transcript_number, None) == j - 1 and j > 0 and self[j].start == self[j - 1].end: + elif ( + self[j].pre.get(transcript_number, None) == j - 1 + and j > 0 + and self[j].start == self[j - 1].end + ): continue break @@ -628,10 +865,15 @@ def get_exon_support_matrix(self, exons): for e_nr, e in enumerate(exons[1:-1]): j1 = next((j for j in range(j2, len(self)) if self[j].end > e[0])) # j1: index of first segment ending after exon start (i.e. first overlapping segment) - j2 = next((j - 1 for j in range(j1, len(self)) if self[j].start >= e[1]), len(self) - 1) + j2 = next( + (j - 1 for j in range(j1, len(self)) if self[j].start >= e[1]), + len(self) - 1, + ) # j2: index of last segment starting befor exon end (i.e. 
last overlapping segment) if self[j1].start <= e[0] and self[j2].end >= e[1]: - covered = set.intersection(*(set(self[j].suc) for j in range(j1, j2 + 1))) + covered = set.intersection( + *(set(self[j].suc) for j in range(j1, j2 + 1)) + ) if covered: esm[covered, e_nr + 1] = True return esm @@ -646,21 +888,29 @@ def get_exonic_region(self): return regs def get_intersects(self, exons): - '''Computes the splice junction exonic overlap of a new transcript with the segment graph. + """Computes the splice junction exonic overlap of a new transcript with the segment graph. :param exons: A list of exon tuples representing the transcript :type exons: list - :return: the splice junction overlap and exonic overlap''' + :return: the splice junction overlap and exonic overlap""" intersect = [0, 0] i = j = 0 while True: - if self[j][0] == exons[i][0] and any(self[k][1] < self[j][0] for k in self[j].pre.values()): - intersect[0] += 1 # same position and actual splice junction(not just tss or pas and internal junction) - if self[j][1] == exons[i][1] and any(self[k][0] > self[j][1] for k in self[j].suc.values()): + if self[j][0] == exons[i][0] and any( + self[k][1] < self[j][0] for k in self[j].pre.values() + ): + intersect[ + 0 + ] += 1 # same position and actual splice junction(not just tss or pas and internal junction) + if self[j][1] == exons[i][1] and any( + self[k][0] > self[j][1] for k in self[j].suc.values() + ): intersect[0] += 1 if self[j][1] > exons[i][0] and exons[i][1] > self[j][0]: # overlap - intersect[1] += min(self[j][1], exons[i][1]) - max(self[j][0], exons[i][0]) + intersect[1] += min(self[j][1], exons[i][1]) - max( + self[j][0], exons[i][0] + ) if exons[i][1] < self[j][1]: i += 1 else: @@ -670,10 +920,16 @@ def get_intersects(self, exons): @deprecated def _find_ts_candidates(self, coverage): - '''Computes a metric indicating template switching.''' + """Computes a metric indicating template switching.""" for i, gnode in enumerate(self._graph[:-1]): - if 
self._graph[i + 1].start == gnode.end: # jump candidates: introns that start within an exon - jumps = {idx: n for idx, n in gnode.suc.items() if n > i + 1 and self._graph[n].start == self._graph[n - 1].end} + if ( + self._graph[i + 1].start == gnode.end + ): # jump candidates: introns that start within an exon + jumps = { + idx: n + for idx, n in gnode.suc.items() + if n > i + 1 and self._graph[n].start == self._graph[n - 1].end + } # find jumps (n>i+1) and check wether they end within an exon begin(jumptarget)==end(node before) jump_weight = {} for idx, target in jumps.items(): @@ -682,7 +938,11 @@ def _find_ts_candidates(self, coverage): jump_weight[target][1].append(idx) for target, (w, idx) in jump_weight.items(): - long_idx = set(idx for idx, n in gnode.suc.items() if n == i + 1) & set(idx for idx, n in self[target].pre.items() if n == target - 1) + long_idx = set( + idx for idx, n in gnode.suc.items() if n == i + 1 + ) & set( + idx for idx, n in self[target].pre.items() if n == target - 1 + ) try: longer_weight = coverage[:, list(long_idx)].sum() except IndexError: @@ -691,20 +951,26 @@ def _find_ts_candidates(self, coverage): yield gnode.end, self[target].start, w, longer_weight, idx def _is_spliced(self, transcript_id, node_index1, node_index2): - 'checks if transcript is spliced (e.g. has an intron) between nodes ni1 and ni2' - if any(self[i].end < self[i + 1].start for i in range(node_index1, node_index2)): # all transcripts are spliced + "checks if transcript is spliced (e.g. 
has an intron) between nodes ni1 and ni2" + if any( + self[i].end < self[i + 1].start for i in range(node_index1, node_index2) + ): # all transcripts are spliced return True if all(transcript_id in self[i].suc for i in range(node_index1, node_index2)): return False return True def _get_next_spliced(self, transcript_id: int, node: int): - 'find the next spliced node for given transcript' + "find the next spliced node for given transcript" while node != self._pas[transcript_id]: try: - next_node = self[node].suc[transcript_id] # raises error if transcript_id not in node.suc + next_node = self[node].suc[ + transcript_id + ] # raises error if transcript_id not in node.suc except KeyError: - logger.error('transcript_id %s seems to be not in node %s', transcript_id, node) + logger.error( + "transcript_id %s seems to be not in node %s", transcript_id, node + ) raise if self[next_node].start > self[node].end: return next_node @@ -712,12 +978,16 @@ def _get_next_spliced(self, transcript_id: int, node: int): return None def _get_exon_end(self, transcript_id: int, node: int): - 'find the end of the exon to which node belongs for given transcript' + "find the end of the exon to which node belongs for given transcript" while node != self._pas[transcript_id]: try: - next_node = self[node].suc[transcript_id] # raises error if transcript_id not in node.suc + next_node = self[node].suc[ + transcript_id + ] # raises error if transcript_id not in node.suc except KeyError: - logger.error('transcript_id %s seems to be not in node %s', transcript_id, node) + logger.error( + "transcript_id %s seems to be not in node %s", transcript_id, node + ) raise if self[next_node].start > self[node].end: return node @@ -725,18 +995,22 @@ def _get_exon_end(self, transcript_id: int, node: int): return node def _get_exon_end_all(self, node: int): - 'find the end of the exon considering all transcripts' + "find the end of the exon considering all transcripts" while node < len(self) - 1 and 
self[node].end == self[node + 1].start: node += 1 return node def _get_exon_start(self, transcript_id: int, node: int): - 'find the start of the exon to which node belongs for given transcript' + "find the start of the exon to which node belongs for given transcript" while node != self._tss[transcript_id]: try: - next_node = self[node].pre[transcript_id] # raises error if transcript_id not in node.pre + next_node = self[node].pre[ + transcript_id + ] # raises error if transcript_id not in node.pre except KeyError: - logger.error('transcript_id %s seems to be not in node %s', transcript_id, node) + logger.error( + "transcript_id %s seems to be not in node %s", transcript_id, node + ) raise if self[next_node].end < self[node].start: return node @@ -744,36 +1018,45 @@ def _get_exon_start(self, transcript_id: int, node: int): return node def _get_exon_start_all(self, node): - 'find the start of the exon considering all transcripts' + "find the start of the exon considering all transcripts" while node > 0 and self[node - 1].end == self[node].start: node -= 1 return node def _find_splice_bubbles_at_position(self, types: list[ASEType], pos): - '''function to refind bubbles at a certain genomic position. + """function to refind bubbles at a certain genomic position. This turns out to be fundamentally different compared to iterating over all bubbles, hence it is a complete rewrite of the function. On the positive site, the functions can validate each other. I tried to reuse the variable names. - If both functions yield same results, there is a good chance that the complex code is actually right.''' + If both functions yield same results, there is a good chance that the complex code is actually right. + """ # TODO: format of pos isn't documented anywhere and the intended isn't clear from the code # more a comment than a docstring... 
- if any(type in ['ES', '3AS', '5AS', 'IR', 'ME'] for type in types): + if any(type in ["ES", "3AS", "5AS", "IR", "ME"] for type in types): try: i, node_A = self._get_node_ending_at(pos[0]) if len(pos) == 3: - middle = [next(idx for idx, node in enumerate(self[i:], i) if node.start > pos[1])] + middle = [ + next( + idx + for idx, node in enumerate(self[i:], i) + if node.start > pos[1] + ) + ] j, node_B = self._get_node_starting_at(pos[2], middle[0]) else: j, node_B = self._get_node_starting_at(pos[-1], i) middle = range(i + 2, j) except StopIteration as e: - raise ValueError(f"cannot find segments at {pos} in segment graph") from e + raise ValueError( + f"cannot find segments at {pos} in segment graph" + ) from e direct: set[int] = set() # primary indirect: dict[ASEType, set[int]] = { - 'ES': set(), - '3AS': set(), - '5AS': set(), - 'IR': set() + "ES": set(), + "3AS": set(), + "5AS": set(), + "IR": set(), } for transcript, node_id in node_A.suc.items(): if transcript not in node_B.pre: @@ -784,31 +1067,31 @@ def _find_splice_bubbles_at_position(self, types: list[ASEType], pos): five_prime = self[node_id].start == node_A.end three_prime = self[node_B.pre[transcript]].end == node_B.start if five_prime and three_prime: - type = 'IR' + type = "IR" elif five_prime: - type = '5AS' if self.strand == '+' else '3AS' + type = "5AS" if self.strand == "+" else "3AS" elif three_prime: - type = '3AS' if self.strand == '+' else '5AS' + type = "3AS" if self.strand == "+" else "5AS" else: - type = 'ES' + type = "ES" indirect[type].add(transcript) for type in types: - if type in ['ES', '3AS', '5AS', 'IR'] and direct and indirect[type]: + if type in ["ES", "3AS", "5AS", "IR"] and direct and indirect[type]: yield list(direct), list(indirect[type]), i, j, type - elif type == 'ME' and len(indirect['ES']) > 2: + elif type == "ME" and len(indirect["ES"]) > 2: me: list[ASEvent] = list() seen_alt = set() for middle_idx in middle: # alternative exons before the middle node, primary exons 
after the middle node (or the middle node itself) alt, prim = set(), set() - for transcript in indirect['ES']: + for transcript in indirect["ES"]: if node_B.pre[transcript] < middle_idx: alt.add(transcript) elif node_A.suc[transcript] >= middle_idx: prim.add(transcript) # make sure there is at least one new alt transcript with this middle node. if prim and alt - seen_alt: - me.append((list(prim), list(alt), i, j, 'ME')) + me.append((list(prim), list(alt), i, j, "ME")) seen_alt.update(alt) seen_prim = set() for me_event in reversed(me): @@ -816,31 +1099,54 @@ def _find_splice_bubbles_at_position(self, types: list[ASEType], pos): if me_event[0] - seen_prim: yield me_event seen_prim.update(me_event[0]) - if any(type in ['TSS', 'PAS'] for type in types): + if any(type in ["TSS", "PAS"] for type in types): try: i, _ = next((idx, n) for idx, n in enumerate(self) if n.start >= pos[0]) - j, _ = next(((idx, n) for idx, n in enumerate(self[i:], i) if n.end >= pos[-1]), (len(self) - 1, self[-1])) + j, _ = next( + ((idx, n) for idx, n in enumerate(self[i:], i) if n.end >= pos[-1]), + (len(self) - 1, self[-1]), + ) except StopIteration as e: - raise ValueError(f"cannot find segments at {pos} in segment graph") from e + raise ValueError( + f"cannot find segments at {pos} in segment graph" + ) from e - alt_types = ['TSS', 'PAS'] if self.strand == '+' else ['PAS', 'TSS'] + alt_types = ["TSS", "PAS"] if self.strand == "+" else ["PAS", "TSS"] # TODO: Second condition is always false, because node_B starts at pos[-1] if any(type == alt_types[0] for type in types) and node_B.end == pos[-1]: - alt = {transcript for transcript, tss in enumerate(self._tss) if i <= tss <= j and self._get_exon_end(transcript, tss) == j} + alt = { + transcript + for transcript, tss in enumerate(self._tss) + if i <= tss <= j and self._get_exon_end(transcript, tss) == j + } if alt: # find compatible alternatives: end after tss /start before pas - prim = [transcript for transcript, pas in enumerate(self._pas) 
if transcript not in alt and pas > j] # prim={transcript for transcript in range(len(self._tss)) if transcript not in alt} + prim = [ + transcript + for transcript, pas in enumerate(self._pas) + if transcript not in alt and pas > j + ] # prim={transcript for transcript in range(len(self._tss)) if transcript not in alt} if prim: yield list(prim), list(alt), i, j, alt_types[0] # TODO: Second condition is always false, because node_A ends at pos[0] if any(type == alt_types[1] for type in types) and node_A.start == pos[0]: - alt = {transcript for transcript, pas in enumerate(self._pas) if i <= pas <= j and self._get_exon_start(transcript, pas) == i} + alt = { + transcript + for transcript, pas in enumerate(self._pas) + if i <= pas <= j and self._get_exon_start(transcript, pas) == i + } if alt: - prim = [transcript for transcript, tss in enumerate(self._tss) if transcript not in alt and tss < i] + prim = [ + transcript + for transcript, tss in enumerate(self._tss) + if transcript not in alt and tss < i + ] if prim: yield list(prim), list(alt), i, j, alt_types[1] - def find_splice_bubbles(self, types: Optional[str | list[ASEType]] = None, pos=None): - '''Searches for alternative paths in the segment graph ("bubbles"). + def find_splice_bubbles( + self, types: Optional[str | list[ASEType]] = None, pos=None + ): + """Searches for alternative paths in the segment graph ("bubbles"). Bubbles are defined as combinations of nodes x_s and x_e with more than one path from x_s to x_e. @@ -851,19 +1157,25 @@ def find_splice_bubbles(self, types: Optional[str | list[ASEType]] = None, pos=N :return: Tuple with 1) transcript indices of primary (e.g. 
most direct) paths and 2) alternative paths respectively, as well as 3) start and 4) end node ids and 5) type of alternative event - ('ES', '3AS', '5AS', 'IR', 'ME', 'TSS', 'PAS')''' + ('ES', '3AS', '5AS', 'IR', 'ME', 'TSS', 'PAS')""" if types is None: - types: list[ASEType] = ('ES', '3AS', '5AS', 'IR', 'ME', 'TSS', 'PAS') + types: list[ASEType] = ("ES", "3AS", "5AS", "IR", "ME", "TSS", "PAS") elif isinstance(types, str): types = (types,) - alt_types: list[ASEType] = ('ES', '5AS', '3AS', 'IR', 'ME', 'PAS', "TSS") if self.strand == '-' else ('ES', '3AS', '5AS', 'IR', 'ME', 'TSS', "PAS") + alt_types: list[ASEType] = ( + ("ES", "5AS", "3AS", "IR", "ME", "PAS", "TSS") + if self.strand == "-" + else ("ES", "3AS", "5AS", "IR", "ME", "TSS", "PAS") + ) if pos is not None: - for prim, alt, i, j, alt_type in self._find_splice_bubbles_at_position(types, pos): + for prim, alt, i, j, alt_type in self._find_splice_bubbles_at_position( + types, pos + ): yield list(prim), list(alt), i, j, alt_type return - if any(type in types for type in ('ES', '3AS', '5AS', 'IR', 'ME')): + if any(type in types for type in ("ES", "3AS", "5AS", "IR", "ME")): # list of spliced and unspliced transcripts joining in B inB_sets: list[tuple[set[int], set[int]]] = [(set(), set())] # node_matrix=self.get_node_matrix() @@ -882,36 +1194,74 @@ def find_splice_bubbles(self, types: Optional[str | list[ASEType]] = None, pos=N for transcript_id, node_id in node_A.suc.items(): outA_sets.setdefault(node_id, set()).add(transcript_id) unspliced = node_A.end == self[junctions[0]].start - alternative: tuple[set[int], set[int]] = ({}, outA_sets[junctions[0]]) if unspliced else (outA_sets[junctions[0]], {}) + alternative: tuple[set[int], set[int]] = ( + ({}, outA_sets[junctions[0]]) + if unspliced + else (outA_sets[junctions[0]], {}) + ) # node_C_dict aims to avoid recalculation of node_C for ME events - # transcript_id -> node at start of 2nd exon C for transcript_id such that there is one exon (B) (and both flanking 
introns) between node_A and C; None if transcript ends + # transcript_id -> node at start of 2nd exon C for transcript_id such that there is one exon (B) (and both flanking introns) between node_A and C; + # None if transcript ends node_C_dict: dict[int, int | None] = {} # ensure that only ME events with novel transcript_id are reported me_alt_seen = set() - logger.debug('checking node %s: %s (%s)', i, node_A, list(zip(junctions, [outA_sets[j] for j in junctions]))) + logger.debug( + "checking node %s: %s (%s)", + i, + node_A, + list(zip(junctions, [outA_sets[j] for j in junctions])), + ) # start from second, as first does not have an alternative for j_idx, junction in enumerate(junctions[1:], 1): # check that transcripts extend beyond node_B - alternative = [{transcript_id for transcript_id in alternative[i] if self._pas[transcript_id] > junction} for i in range(2)] + alternative = [ + { + transcript_id + for transcript_id in alternative[i] + if self._pas[transcript_id] > junction + } + for i in range(2) + ] logger.debug(alternative) # alternative transcript sets for the 4 types - found = [trL1.intersection(trL2) for trL1 in alternative for trL2 in inB_sets[junction]] + found = [ + trL1.intersection(trL2) + for trL1 in alternative + for trL2 in inB_sets[junction] + ] # 5th type: mutually exclusive (outdated handling of ME for reference) # found.append(set.union(*alternative)-inB_sets[junction][0]-inB_sets[junction][1]) - logger.debug('checking junction %s (transcript_id=%s) and found %s at B=%s', junction, outA_sets[junction], found, inB_sets[junction]) + logger.debug( + "checking junction %s (transcript_id=%s) and found %s at B=%s", + junction, + outA_sets[junction], + found, + inB_sets[junction], + ) for alt_type_id, alt in enumerate(found): if alt_types[alt_type_id] in types and alt: - yield list(outA_sets[junction]), list(alt), i, junction, alt_types[alt_type_id] + yield list(outA_sets[junction]), list( + alt + ), i, junction, alt_types[alt_type_id] # 
me_alt=set.union(*alternative)-inB_sets[junction][0]-inB_sets[junction][1] #search 5th type: mutually exclusive - if 'ME' in types: + if "ME" in types: # search 5th type: mutually exclusive - needs to be spliced - me_alt = alternative[0] - inB_sets[junction][0] - inB_sets[junction][1] + me_alt = ( + alternative[0] + - inB_sets[junction][0] + - inB_sets[junction][1] + ) # there is at least one novel alternative transcript if me_alt - me_alt_seen: # for ME we need to find (potentially more than one) node_C where the alternatives rejoin # find node_C for all me_alt for transcript_id in me_alt: - node_C_dict.setdefault(transcript_id, self._get_next_spliced(transcript_id, node_A.suc[transcript_id])) + node_C_dict.setdefault( + transcript_id, + self._get_next_spliced( + transcript_id, node_A.suc[transcript_id] + ), + ) # transcript end in node_B, no node_C if node_C_dict[transcript_id] is None: # those are not of interest for ME @@ -923,23 +1273,40 @@ def find_splice_bubbles(self, types: Optional[str | list[ASEType]] = None, pos=N # primary transcripts for transcript_id in outA_sets[node_B_i]: # find node_C - node_C_dict.setdefault(transcript_id, self._get_next_spliced(transcript_id, node_B_i)) + node_C_dict.setdefault( + transcript_id, + self._get_next_spliced(transcript_id, node_B_i), + ) if node_C_dict[transcript_id] is None: continue # first, all primary transcript_id/nCs from junction added if node_B_i == junction: - inC_sets.setdefault(node_C_dict[transcript_id], set()).add(transcript_id) + inC_sets.setdefault( + node_C_dict[transcript_id], set() + ).add(transcript_id) # then add primary transcript_id that also rejoin at any of the junction nC elif node_C_dict[transcript_id] in inC_sets: - inC_sets[node_C_dict[transcript_id]].add(transcript_id) + inC_sets[node_C_dict[transcript_id]].add( + transcript_id + ) # no node_C for any of the junction primary transcript_id - no need to check the other primaries if not inC_sets: break for node_C_i, me_prim in 
sorted(inC_sets.items()): - found_alt = {transcript_id for transcript_id in me_alt if node_C_dict[transcript_id] == node_C_i} + found_alt = { + transcript_id + for transcript_id in me_alt + if node_C_dict[transcript_id] == node_C_i + } # ensure, there is a new alternative if found_alt - me_alt_seen: - yield (list(me_prim), list(found_alt), i, node_C_i, 'ME') + yield ( + list(me_prim), + list(found_alt), + i, + node_C_i, + "ME", + ) # me_alt=me_alt-found_alt me_alt_seen.update(found_alt) # now transcripts supporting junction join the alternatives @@ -947,15 +1314,18 @@ def find_splice_bubbles(self, types: Optional[str | list[ASEType]] = None, pos=N if "TSS" in types or "PAS" in types: yield from self._find_start_end_events(types) - def _find_start_end_events(self, types: list[ASEType]) -> Generator[ASEvent, None, None]: - '''Searches for alternative TSS/PAS in the segment graph. + def _find_start_end_events( + self, types: list[ASEType] + ) -> Generator[ASEvent, None, None]: + """Searches for alternative TSS/PAS in the segment graph. All transcripts sharing the same first/ last node in the splice graph are summarized. All pairs of TSS/PAS are returned. The primary set is the set with the smaller coordinate, the alternative the one with the larger coordinate. 
:return: Tuple with 1) transcript ids sharing common start exon and 2) alternative transcript ids respectively, - as well as 3) start and 4) end node ids of the exon and 5) type of alternative event ("TSS" or "PAS")''' + as well as 3) start and 4) end node ids of the exon and 5) type of alternative event ("TSS" or "PAS") + """ tss: dict[int, set[int]] = {} pas: dict[int, set[int]] = {} # tss_start: dict[int, int] = {} @@ -963,29 +1333,55 @@ def _find_start_end_events(self, types: list[ASEType]) -> Generator[ASEvent, Non for transcript_id, (start1, end1) in enumerate(zip(self._tss, self._pas)): tss.setdefault(start1, set()).add(transcript_id) pas.setdefault(end1, set()).add(transcript_id) - alt_types: list[ASEType] = ['PAS', 'TSS'] if self.strand == '-' else ['TSS', 'PAS'] + alt_types: list[ASEType] = ( + ["PAS", "TSS"] if self.strand == "-" else ["TSS", "PAS"] + ) if alt_types[0] in types: - for (prim_node_id, prim_set), (alt_node_id, alt_set) in itertools.combinations(sorted(tss.items(), key=lambda item: item[0]), 2): - yield (list(prim_set), list(alt_set), prim_node_id, alt_node_id, alt_types[0]) + for (prim_node_id, prim_set), ( + alt_node_id, + alt_set, + ) in itertools.combinations( + sorted(tss.items(), key=lambda item: item[0]), 2 + ): + yield ( + list(prim_set), + list(alt_set), + prim_node_id, + alt_node_id, + alt_types[0], + ) if alt_types[1] in types: - for (prim_node_id, prim_set), (alt_node_id, alt_set) in itertools.combinations(sorted(pas.items(), key=lambda item: item[0]), 2): - yield (list(prim_set), list(alt_set), prim_node_id, alt_node_id, alt_types[1]) + for (prim_node_id, prim_set), ( + alt_node_id, + alt_set, + ) in itertools.combinations( + sorted(pas.items(), key=lambda item: item[0]), 2 + ): + yield ( + list(prim_set), + list(alt_set), + prim_node_id, + alt_node_id, + alt_types[1], + ) def is_exonic(self, position): - '''Checks whether the position is within an exon. + """Checks whether the position is within an exon. 
:param position: The genomic position to check. - :return: True, if the position overlaps with an exon, else False.''' + :return: True, if the position overlaps with an exon, else False.""" for node in self: if node[0] <= position and node[1] >= position: return True return False def _get_all_exons(self, nodeX, nodeY, transcript): - 'get all exonic regions between (including) nodeX to nodeY for transcripts transcript' + "get all exonic regions between (including) nodeX to nodeY for transcripts transcript" # TODO: add option to extend first and last exons node = max(nodeX, self._tss[transcript]) # if tss>nodeX start there - if transcript not in self[node].pre and self._tss[transcript] != node: # nodeX is not part of transcript + if ( + transcript not in self[node].pre and self._tss[transcript] != node + ): # nodeX is not part of transcript # find first node in transcript after nodeX but before nodeY for i in range(node, nodeY + 1): if transcript in self[node].suc: @@ -1017,23 +1413,31 @@ def __len__(self): return len(self._graph) def events_dist(self, event1: ASEvent, event2: ASEvent): - ''' + """ returns the distance (in nucleotides) between two Alternative Splicing Events. 
:param event1: event obtained from .find_splice_bubbles() :param event2: event obtained from .find_splice_bubbles() - ''' + """ # the event begins at the beginning of the first exon and ends at the end of the last exon - e1_coor = [self[event1[2]].start, self[event1[3]].end] # starting and ending coordinates of event 1 - e2_coor = [self[event2[2]].start, self[event2[3]].end] # starting and ending coordinates of event 2 + e1_coor = [ + self[event1[2]].start, + self[event1[3]].end, + ] # starting and ending coordinates of event 1 + e2_coor = [ + self[event2[2]].start, + self[event2[3]].end, + ] # starting and ending coordinates of event 2 return _interval_dist(e1_coor, e2_coor) - def _get_node_starting_at(self, coordinate: int, start_index=0) -> tuple[int, SegGraphNode]: - ''' + def _get_node_starting_at( + self, coordinate: int, start_index=0 + ) -> tuple[int, SegGraphNode]: + """ return the node in the splice graph starting at the given coordinate. - ''' + """ for i, node in enumerate(self[start_index:], start_index): if node.start == coordinate: return i, node @@ -1041,10 +1445,12 @@ def _get_node_starting_at(self, coordinate: int, start_index=0) -> tuple[int, Se return -1, None return -1, None - def _get_node_ending_at(self, coordinate: int, start_index=0) -> tuple[int, SegGraphNode]: - ''' + def _get_node_ending_at( + self, coordinate: int, start_index=0 + ) -> tuple[int, SegGraphNode]: + """ return the node in the splice graph ending at the given coordinate. 
- ''' + """ for i, node in enumerate(self[start_index:], start_index): if node.end == coordinate: return i, node @@ -1062,7 +1468,8 @@ def _get_event_coordinate(self, event: ASEvent): class SegGraphNode(tuple): - '''A node in a segment graph represents an exonic segment.''' + """A node in a segment graph represents an exonic segment.""" + def __new__(cls, start, end, pre=None, suc=None): if pre is None: pre = dict() @@ -1075,33 +1482,35 @@ def __getnewargs__(self): @property def start(self) -> int: - '''the (genomic 5') start of the segment''' + """the (genomic 5') start of the segment""" return self.__getitem__(0) @property def end(self) -> int: - '''the (genomic 3') end of the segment''' + """the (genomic 3') end of the segment""" return self.__getitem__(1) @property def pre(self) -> dict[int, int]: - '''the predecessor segments of the segment (linked nodes upstream)''' + """the predecessor segments of the segment (linked nodes upstream)""" return self.__getitem__(2) @property def suc(self) -> dict[int, int]: - '''the successor segments of the segment (linked nodes downstream)''' + """the successor segments of the segment (linked nodes downstream)""" return self.__getitem__(3) -class SpliceGraph(): - '''(Experimental) Splice Graph Implementation +class SpliceGraph: + """(Experimental) Splice Graph Implementation Nodes represent splice sites and are tuples of genomic positions and a "lower" flag. The "lower flag" is true, if the splice site is a genomic 5' end of an exon Nodes are kept sorted, so iteration over splicegraph returns all nodes in genomic order Edges are assessed with SpliceGraph.pre(node, [transcript_number]) and SpliceGraph.suc(node, [transcript_number]) functions. 
- If no transcript_number is provided, a dict with all incoming/outgoing edges is returned''' + If no transcript_number is provided, a dict with all incoming/outgoing edges is returned + """ + # @experimental def __init__(self, is_reverse, graph, fwd_starts, rev_starts): @@ -1112,15 +1521,15 @@ def __init__(self, is_reverse, graph, fwd_starts, rev_starts): @classmethod def from_transcript_list(cls, exon_lists, strand): - '''Compute the splice graph from a list of transcripts + """Compute the splice graph from a list of transcripts :param exon_lists: A list of transcripts, which are lists of exons, which in turn are (start,end) tuples :type exon_lists: list :param strand: the strand of the gene, either "+" or "-" :return: The SpliceGraph object - :rtype: SpliceGraph''' + :rtype: SpliceGraph""" - assert strand in '+-', 'strand must be either "+" or "-"' + assert strand in "+-", 'strand must be either "+" or "-"' graph = SortedDict() fwd_starts = [(exons[0][0], True) for exons in exon_lists] # genomic 5' @@ -1129,11 +1538,17 @@ def from_transcript_list(cls, exon_lists, strand): for transcript_number, exons in enumerate(exon_lists): graph.setdefault((exons[0][0], True), ({}, {})) - for i, (b1, b2) in enumerate(pairwise(pos for exon in exons for pos in exon)): + for i, (b1, b2) in enumerate( + pairwise(pos for exon in exons for pos in exon) + ): graph.setdefault((b2, bool(i % 2)), ({}, {})) - graph[b2, bool(i % 2)][1][transcript_number] = b1, not bool((i) % 2) # successor - graph[b1, not bool(i % 2)][0][transcript_number] = b2, bool((1) % 2) # predesessor - sg = cls(strand == '-', graph, fwd_starts, rev_starts) + graph[b2, bool(i % 2)][1][transcript_number] = b1, not bool( + (i) % 2 + ) # successor + graph[b1, not bool(i % 2)][0][transcript_number] = b2, bool( + (1) % 2 + ) # predesessor + sg = cls(strand == "-", graph, fwd_starts, rev_starts) return sg def __iter__(self): @@ -1144,12 +1559,12 @@ def __len__(self): @experimental def add(self, exons) -> None: - ''' + """ 
         Add one transcript to the existing graph.

         :param exons: A list of exon tuples representing the transcript to add
         :type exons: list
-        '''
+        """
         transcript_number = len(self._fwd_starts)
         self._fwd_starts.append((exons[0][0], True)) # genomic 5'
@@ -1157,28 +1572,30 @@ def add(self, exons) -> None:
         self._graph.setdefault((exons[0][0], True), ({}, {}))
         for i, (b1, b2) in enumerate(pairwise(pos for exon in exons for pos in exon)):
             self._graph.setdefault((b2, bool(i % 2)), ({}, {}))
-            self._graph[b2, bool(i % 2)][1][transcript_number] = b1, not bool((i) % 2) # successor
+            self._graph[b2, bool(i % 2)][1][transcript_number] = b1, not bool(
+                (i) % 2
+            ) # successor
             self._graph[b1, not bool(i % 2)][0][transcript_number] = b2, bool((1) % 2)

     def suc(self, node, transcript_number=None) -> Union[int, dict]:
-        '''get index of successor node (next genomic upstream node) of transcript, or, if transcript_number is omitted, a dict with successors for all transcripts.
+        """get index of successor node (next genomic upstream node) of transcript, or, if transcript_number is omitted, a dict with successors for all transcripts.

         :param node: index of the originating node
         :type node: int
         :param transcript_number: index of the transcript (optional)
-        :type transcript_number: int'''
+        :type transcript_number: int"""
         edges = self._graph[node][0]
         if transcript_number is None:
             return edges
         return edges[transcript_number]

     def pre(self, node, transcript_number=None) -> Union[int, dict]:
-        '''get index of predesessor node (next genomic downstream node) of transcript, or, if transcript_number is omitted, a dict with predesessors for all transcripts.
+        """get index of predecessor node (next genomic downstream node) of transcript, or, if transcript_number is omitted, a dict with predecessors for all transcripts.
:param node: index of the originating node :type node: int :param transcript_number: index of the transcript (optional) - :type transcript_number: int''' + :type transcript_number: int""" edges = self._graph[node][1] if transcript_number is None: diff --git a/src/isotools/transcriptome.py b/src/isotools/transcriptome.py index 553152e..8330fd6 100644 --- a/src/isotools/transcriptome.py +++ b/src/isotools/transcriptome.py @@ -7,9 +7,16 @@ from typing import Optional, TypedDict from ._transcriptome_io import import_ref_transcripts from .gene import Gene -from ._transcriptome_filter import DEFAULT_GENE_FILTER, DEFAULT_TRANSCRIPT_FILTER, DEFAULT_REF_TRANSCRIPT_FILTER, ANNOTATION_VOCABULARY, SPLICE_CATEGORY +from ._transcriptome_filter import ( + DEFAULT_GENE_FILTER, + DEFAULT_TRANSCRIPT_FILTER, + DEFAULT_REF_TRANSCRIPT_FILTER, + ANNOTATION_VOCABULARY, + SPLICE_CATEGORY, +) from . import __version__ -logger = logging.getLogger('isotools') + +logger = logging.getLogger("isotools") # as this class has diverse functionality, its split among: # transcriptome.py (this file- initialization and user level basic functions) @@ -24,224 +31,297 @@ class FilterData(TypedDict): transcript: dict[str, str] reference: dict[str, str] + class InfosData(TypedDict): biases: bool + class Transcriptome: - '''Contains sequencing data and annotation for Long Read Transcriptome Sequencing (LRTS) Experiments. - ''' + """Contains sequencing data and annotation for Long Read Transcriptome Sequencing (LRTS) Experiments.""" + # initialization and save/restore data data: dict[str, IntervalTree[Gene]] - 'One IntervalTree of Genes for each chromosome.' + "One IntervalTree of Genes for each chromosome." 
infos: InfosData chimeric: dict filter: FilterData _idx: dict[str, Gene] - def __init__(self, data: Optional[dict[str, IntervalTree[Gene]]] = None, infos = dict(), chimeric = dict(), filter = dict()): - '''Constructor method''' + def __init__( + self, + data: Optional[dict[str, IntervalTree[Gene]]] = None, + infos=None, + chimeric=None, + filter=None, + ): + if infos is None: + infos = {} + if chimeric is None: + chimeric = {} + if filter is None: + filter = {} + + """Constructor method""" if data is not None: self.data = data self.infos = infos self.chimeric = chimeric self.filter = filter - assert 'reference_file' in self.infos + assert "reference_file" in self.infos self.make_index() @classmethod - def from_reference(cls, reference_file: str, file_format='auto', **kwargs): - '''Creates a Transcriptome object by importing reference annotation. + def from_reference(cls, reference_file: str, file_format="auto", **kwargs): + """Creates a Transcriptome object by importing reference annotation. :param reference_file: Reference file in gff3 format or pickle file to restore previously imported annotation :param file_format: Specify the file format of the provided reference_file. If set to "auto" the file type is inferred from the extension. 
- :param chromosome: If reference file is gtf/gff, restrict import on specified chromosomes ''' - - if file_format == 'auto': - file_format = os.path.splitext(reference_file)[1].lstrip('.') - if file_format == 'gz': - file_format = os.path.splitext(reference_file[:-3])[1].lstrip('.') - if file_format in ('gff', 'gff3', 'gtf'): - logger.info('importing reference from %s file %s', file_format, reference_file) + :param chromosome: If reference file is gtf/gff, restrict import on specified chromosomes + """ + + if file_format == "auto": + file_format = os.path.splitext(reference_file)[1].lstrip(".") + if file_format == "gz": + file_format = os.path.splitext(reference_file[:-3])[1].lstrip(".") + if file_format in ("gff", "gff3", "gtf"): + logger.info( + "importing reference from %s file %s", file_format, reference_file + ) transcriptome = cls() transcriptome.chimeric = {} - transcriptome.data = import_ref_transcripts(reference_file, transcriptome, file_format, **kwargs) - transcriptome.infos = {'reference_file': reference_file, 'isotools_version': __version__} - transcriptome.filter = {'gene': DEFAULT_GENE_FILTER.copy(), - 'transcript': DEFAULT_TRANSCRIPT_FILTER.copy(), - 'reference': DEFAULT_REF_TRANSCRIPT_FILTER.copy()} + transcriptome.data = import_ref_transcripts( + reference_file, transcriptome, file_format, **kwargs + ) + transcriptome.infos = { + "reference_file": reference_file, + "isotools_version": __version__, + } + transcriptome.filter = { + "gene": DEFAULT_GENE_FILTER.copy(), + "transcript": DEFAULT_TRANSCRIPT_FILTER.copy(), + "reference": DEFAULT_REF_TRANSCRIPT_FILTER.copy(), + } for subcat in ANNOTATION_VOCABULARY: - tag = '_'.join(re.findall(r'\b\w+\b', subcat)).upper() + tag = "_".join(re.findall(r"\b\w+\b", subcat)).upper() if tag[0].isdigit(): - tag = '_'+tag - transcriptome.filter['transcript'][tag] = f'"{subcat}" in annotation[1]' + tag = "_" + tag + transcriptome.filter["transcript"][tag] = f'"{subcat}" in annotation[1]' for i, cat in 
enumerate(SPLICE_CATEGORY): - transcriptome.filter['transcript'][cat] = f'annotation[0]=={i}' + transcriptome.filter["transcript"][cat] = f"annotation[0]=={i}" - elif file_format == 'pkl': + elif file_format == "pkl": # warn if kwargs are specified: kwargs are ignored if kwargs: - logger.warning("The following parameters are ignored when loading reference from pkl: %s", ", ".join(kwargs)) + logger.warning( + "The following parameters are ignored when loading reference from pkl: %s", + ", ".join(kwargs), + ) transcriptome = cls.load(reference_file) - if 'sample_table' in transcriptome.infos: - logger.warning('the pickle file seems to contain sample information... extracting reference') + if "sample_table" in transcriptome.infos: + logger.warning( + "the pickle file seems to contain sample information... extracting reference" + ) transcriptome = transcriptome._extract_reference() else: - raise ValueError('invalid file format %s of file %s' % (file_format, reference_file)) + raise ValueError( + "invalid file format %s of file %s" % (file_format, reference_file) + ) transcriptome.make_index() return transcriptome def save(self, pickle_file: str): - '''Saves transcriptome information (including reference) in a pickle file. + """Saves transcriptome information (including reference) in a pickle file. - :param pickle_file: Filename to save data''' - logger.info('saving transcriptome to %s', pickle_file) - pickle.dump(self, open(pickle_file, 'wb')) + :param pickle_file: Filename to save data""" + logger.info("saving transcriptome to %s", pickle_file) + pickle.dump(self, open(pickle_file, "wb")) @classmethod def load(cls, pickle_file: str): - '''Restores transcriptome information from a pickle file. + """Restores transcriptome information from a pickle file. 
- :param pickle_file: Filename to restore data''' + :param pickle_file: Filename to restore data""" - logger.info('loading transcriptome from %s', pickle_file) - transcriptome: Transcriptome = pickle.load(open(pickle_file, 'rb')) - pickled_version = transcriptome.infos.get('isotools_version', '<0.2.6') + logger.info("loading transcriptome from %s", pickle_file) + transcriptome: Transcriptome = pickle.load(open(pickle_file, "rb")) + pickled_version = transcriptome.infos.get("isotools_version", "<0.2.6") if pickled_version != __version__: - logger.warning('This is isotools version %s, but data has been pickled with version %s, which may be incompatible', __version__, pickled_version) + logger.warning( + "This is isotools version %s, but data has been pickled with version %s, which may be incompatible", + __version__, + pickled_version, + ) transcriptome.make_index() return transcriptome def save_reference(self, pickle_file=None): - '''Saves the reference information of a transcriptome in a pickle file. + """Saves the reference information of a transcriptome in a pickle file. 
- :param pickle_file: Filename to save data''' + :param pickle_file: Filename to save data""" if pickle_file is None: - pickle_file = self.infos['reference_file']+'.isotools.pkl' - logger.info('saving reference to %s', pickle_file) + pickle_file = self.infos["reference_file"] + ".isotools.pkl" + logger.info("saving reference to %s", pickle_file) ref_tr = self._extract_reference() - pickle.dump(ref_tr, open(pickle_file, 'wb')) + pickle.dump(ref_tr, open(pickle_file, "wb")) def _extract_reference(self): - if not 'sample_table' not in self.infos: + if not "sample_table" not in self.infos: return self # only reference info - assume that self.data only contains reference data # make a new transcriptome - ref_info = {k: v for k, v in self.infos.items() if k in ['reference_file', 'isotools_version']} + ref_info = { + k: v + for k, v in self.infos.items() + if k in ["reference_file", "isotools_version"] + } ref_transcriptome = type(self)(data={}, infos=ref_info, filter=self.filter) # extract the reference genes and link them to the new ref_tr - keep = {'ID', 'chr', 'strand', 'name', 'reference'} # no coverage, segment_graph, transcripts + keep = { + "ID", + "chr", + "strand", + "name", + "reference", + } # no coverage, segment_graph, transcripts for chrom, tree in self.data.items(): - ref_transcriptome.data[chrom] = IntervalTree(Gene(gene.start, gene.end, {k: v - for k, v in gene.data.items() if k in keep}, ref_transcriptome) - for gene in tree if gene.is_annotated) + ref_transcriptome.data[chrom] = IntervalTree( + Gene( + gene.start, + gene.end, + {k: v for k, v in gene.data.items() if k in keep}, + ref_transcriptome, + ) + for gene in tree + if gene.is_annotated + ) ref_transcriptome.make_index() return ref_transcriptome def make_index(self): - '''Updates the index of gene names and ids (e.g. used by the the [] operator).''' + """Updates the index of gene names and ids (e.g. 
used by the [] operator)."""
         idx = dict()
         for gene in self:
             if gene.id in idx: # at least id should be unique - maybe raise exception?
-                logger.warning('%s seems to be ambigous: %s vs %s', gene.id, str(idx[gene.id]), str(gene))
+                logger.warning(
+                    "%s seems to be ambiguous: %s vs %s",
+                    gene.id,
+                    str(idx[gene.id]),
+                    str(gene),
+                )
             idx[gene.name] = gene
             idx[gene.id] = gene
         self._idx = idx

     # basic user level functionality
     def __getitem__(self, key):
-        '''
+        """
         Syntax: self[key]

         :param key: May either be the gene name or the gene id
-        :return: The gene specified by key.'''
+        :return: The gene specified by key."""
         return self._idx[key]

     def __len__(self):
-        '''Syntax: len(self)
+        """Syntax: len(self)

-        :return: The number of genes.'''
+        :return: The number of genes."""
         return self.n_genes

     def __contains__(self, key):
-        ''' Syntax: key in self
+        """Syntax: key in self

         Checks whether key is in self.

-        :param key: May either be the gene name or the gene id'''
+        :param key: May either be the gene name or the gene id"""
         return key in self._idx

     def remove_chromosome(self, chromosome):
-        '''Deletes the chromosome from the transcriptome
+        """Deletes the chromosome from the transcriptome

-        :param chromosome: Name of the chromosome to remove'''
+        :param chromosome: Name of the chromosome to remove"""
         del self.data[chromosome]
         self.make_index()

-    def _get_sample_idx(self, name_column='name'):
-        'a dict with group names as keys and index lists as values'
+    def _get_sample_idx(self, name_column="name"):
+        "a dict with group names as keys and index lists as values"
         return {sample: i for i, sample in enumerate(self.sample_table[name_column])}

     @property
     def sample_table(self):
-        '''The sample table contains sample names, group information, long read coverage, as well as all other potentially
-        relevant information on the samples.'''
+        """The sample table contains sample names, group information, long read coverage, as well as all other potentially
+        relevant information on the
samples.""" try: - return self.infos['sample_table'] + return self.infos["sample_table"] except KeyError: - return pd.DataFrame(columns=['name', 'file', 'group', 'nonchimeric_reads', 'chimeric_reads'], dtype='object') + return pd.DataFrame( + columns=[ + "name", + "file", + "group", + "nonchimeric_reads", + "chimeric_reads", + ], + dtype="object", + ) @property def samples(self) -> list: - '''An ordered list of sample names.''' + """An ordered list of sample names.""" return list(self.sample_table.name) - def groups(self, by='group') -> dict: - '''Get sample groups as defined in columns of the sample table. + def groups(self, by="group") -> dict: + """Get sample groups as defined in columns of the sample table. :param by: A column name of the sample table that defines the grouping. :return: Dict with groupnames as keys and list of sample names as values. - ''' - return dict(self.sample_table.groupby(by)['name'].apply(list)) + """ + return dict(self.sample_table.groupby(by)["name"].apply(list)) @property def n_transcripts(self) -> int: - '''The total number of transcripts isoforms.''' + """The total number of transcripts isoforms.""" if self.data is None: return 0 return sum(gene.n_transcripts for gene in self) @property def n_genes(self) -> int: - '''The total number of genes.''' + """The total number of genes.""" if self.data is None: return 0 return sum((len(t) for t in self.data.values())) @property def novel_genes(self) -> int: # this is used for id assignment - '''The total number of novel (not reference) genes.''' + """The total number of novel (not reference) genes.""" try: - return self.infos['novel_counter'] + return self.infos["novel_counter"] except KeyError: - self.infos['novel_counter'] = 0 + self.infos["novel_counter"] = 0 return 0 @property def chromosomes(self) -> list: - '''The list of chromosome names.''' + """The list of chromosome names.""" return list(self.data) - def _add_novel_gene(self, chrom, start, end, strand, info, 
novel_prefix='IT_novel_'): + def _add_novel_gene( + self, chrom, start, end, strand, info, novel_prefix="IT_novel_" + ): n_novel = self.novel_genes - info.update({'chr': chrom, 'ID': f'{novel_prefix}{n_novel+1:05d}', 'strand': strand}) + info.update( + {"chr": chrom, "ID": f"{novel_prefix}{n_novel+1:05d}", "strand": strand} + ) g = Gene(start, end, info, self) self.data.setdefault(chrom, IntervalTree()).add(g) - self.infos['novel_counter'] += 1 + self.infos["novel_counter"] += 1 return g def __str__(self): - return '{} object with {} genes and {} transcripts'.format(type(self).__name__, self.n_genes, self.n_transcripts) + return "{} object with {} genes and {} transcripts".format( + type(self).__name__, self.n_genes, self.n_transcripts + ) def __repr__(self): return object.__repr__(self) @@ -250,8 +330,14 @@ def __iter__(self): return (gene for tree in self.data.values() for gene in tree) # IO: load new data from primary data files - from ._transcriptome_io import add_sample_from_bam, add_sample_from_csv, remove_samples, add_short_read_coverage, \ - remove_short_read_coverage, collapse_immune_genes + from ._transcriptome_io import ( + add_sample_from_bam, + add_sample_from_csv, + remove_samples, + add_short_read_coverage, + remove_short_read_coverage, + collapse_immune_genes, + ) # IO: output data as tables or other human readable format from ._transcriptome_io import ( @@ -266,15 +352,38 @@ def __iter__(self): ) # filtering functionality and iterators - from ._transcriptome_filter import add_qc_metrics, add_orf_prediction, add_filter, remove_filter, iter_genes, iter_transcripts, iter_ref_transcripts + from ._transcriptome_filter import ( + add_qc_metrics, + add_orf_prediction, + add_filter, + remove_filter, + iter_genes, + iter_transcripts, + iter_ref_transcripts, + ) # statistic: differential splicing, alternative_splicing_events - from ._transcriptome_stats import die_test, altsplice_test, coordination_test, alternative_splicing_events, rarefaction + from 
._transcriptome_stats import ( + die_test, + altsplice_test, + coordination_test, + alternative_splicing_events, + rarefaction, + ) # statistic: summary tables (can be used as input to plot_bar / plot_dist) - from ._transcriptome_stats import altsplice_stats, filter_stats, transcript_length_hist, transcript_coverage_hist, \ - transcripts_per_gene_hist, exons_per_transcript_hist, downstream_a_hist, direct_repeat_hist, \ - entropy_calculation, str_var_calculation + from ._transcriptome_stats import ( + altsplice_stats, + filter_stats, + transcript_length_hist, + transcript_coverage_hist, + transcripts_per_gene_hist, + exons_per_transcript_hist, + downstream_a_hist, + direct_repeat_hist, + entropy_calculation, + str_var_calculation, + ) # protein domain annotation from .domains import add_hmmer_domains, add_annotation_domains diff --git a/tests/altsplice_test.py b/tests/altsplice_test.py index d1c6004..c886023 100644 --- a/tests/altsplice_test.py +++ b/tests/altsplice_test.py @@ -6,5 +6,5 @@ def test_subcategory(example_gene): sg = example_gene.ref_segment_graph for novel in example_gene.transcripts: - alt_splice = sg.get_alternative_splicing(novel['exons']) - assert novel['transcript_name'] in alt_splice[1] + alt_splice = sg.get_alternative_splicing(novel["exons"]) + assert novel["transcript_name"] in alt_splice[1] diff --git a/tests/conftest.py b/tests/conftest.py index 99dd9f2..e7e615c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,40 +4,74 @@ @pytest.fixture(scope="session") def example_gene(): - ref = [[(12, 20), (30, 40), (50, 60), (70, 81)], - [(11, 20), (35, 40), (75, 79)], - [(10, 20), (30, 40), (50, 60), (70, 80)]] - novel = {'FSM': [(10, 20), (30, 40), (50, 60), (70, 80)], - "3' fragment": [(33, 40), (50, 60), (70, 80)], - "5' fragment": [(10, 20), (30, 40), (50, 55)], - "mono-exon": [(22, 35)], - "exon skipping": [(10, 20), (50, 60), (70, 80)], - "intron retention": [(10, 40), (50, 60), (70, 80)], - "novel combination": [(10, 20), (30, 40), 
(75, 80)], - "novel junction": [(10, 20), (30, 40), (50, 60), (75, 80)], - "novel exonic TSS": [(26, 40), (50, 60), (70, 80)], - "novel exonic PAS": [(10, 20), (30, 40), (50, 66)], - "novel 5' splice site": [(10, 24), (30, 40), (50, 60), (70, 80)], - "novel 3' splice site": [(10, 20), (26, 40), (50, 60), (70, 80)], - "novel exon": [(10, 20), (30, 40), (43, 47), (50, 60), (70, 80)], - "novel intronic TSS": [(43, 47), (50, 60), (70, 80)], - "novel intronic PAS": [(10, 20), (30, 40), (82, 90)]} - ref_tr = [{'exons': e, 'id': f'reference {i+1}'} for i, e in enumerate(ref)] - transcripts = [{'exons': e, 'transcript_name': n} for n, e in novel.items()] - g = Gene(10, 81, {'chr': 'chr1', 'strand': '+', 'ID': 'example', - 'reference': {'transcripts': ref_tr}, 'transcripts': transcripts}, None) + ref = [ + [(12, 20), (30, 40), (50, 60), (70, 81)], + [(11, 20), (35, 40), (75, 79)], + [(10, 20), (30, 40), (50, 60), (70, 80)], + ] + novel = { + "FSM": [(10, 20), (30, 40), (50, 60), (70, 80)], + "3' fragment": [(33, 40), (50, 60), (70, 80)], + "5' fragment": [(10, 20), (30, 40), (50, 55)], + "mono-exon": [(22, 35)], + "exon skipping": [(10, 20), (50, 60), (70, 80)], + "intron retention": [(10, 40), (50, 60), (70, 80)], + "novel combination": [(10, 20), (30, 40), (75, 80)], + "novel junction": [(10, 20), (30, 40), (50, 60), (75, 80)], + "novel exonic TSS": [(26, 40), (50, 60), (70, 80)], + "novel exonic PAS": [(10, 20), (30, 40), (50, 66)], + "novel 5' splice site": [(10, 24), (30, 40), (50, 60), (70, 80)], + "novel 3' splice site": [(10, 20), (26, 40), (50, 60), (70, 80)], + "novel exon": [(10, 20), (30, 40), (43, 47), (50, 60), (70, 80)], + "novel intronic TSS": [(43, 47), (50, 60), (70, 80)], + "novel intronic PAS": [(10, 20), (30, 40), (82, 90)], + } + ref_tr = [{"exons": e, "id": f"reference {i+1}"} for i, e in enumerate(ref)] + transcripts = [{"exons": e, "transcript_name": n} for n, e in novel.items()] + g = Gene( + 10, + 81, + { + "chr": "chr1", + "strand": "+", + "ID": 
"example", + "reference": {"transcripts": ref_tr}, + "transcripts": transcripts, + }, + None, + ) return g @pytest.fixture(scope="session") def example_gene_coor(): ref = [[(0, 10), (20, 30), (40, 50), (60, 70), (80, 90), (100, 110), (120, 130)]] - novel = {'altA_priB': [(0, 10), (20, 30), (40, 50), (60, 70), (100, 110), (120, 130)], - 'altA_altB': [(0, 10), (20, 30), (40, 50), (60, 70), (80, 90), (100, 110), (120, 130)], - 'priA_priB': [(0, 10), (20, 50), (60, 70), (100, 110), (120, 130)], - 'priB_altB': [(0, 10), (20, 50), (60, 70), (80, 90), (100, 110), (120, 130)]} - ref_tr = [{'exons': e, 'id': f'reference {i+1}'} for i, e in enumerate(ref)] - transcripts = [{'exons': e, 'transcript_name': n} for n, e in novel.items()] - g = Gene(1, 100, {'chr': 'chr1', 'strand': '+', 'ID': 'example_coor', - 'reference': {'transcripts': ref_tr}, 'transcripts': transcripts}, None) + novel = { + "altA_priB": [(0, 10), (20, 30), (40, 50), (60, 70), (100, 110), (120, 130)], + "altA_altB": [ + (0, 10), + (20, 30), + (40, 50), + (60, 70), + (80, 90), + (100, 110), + (120, 130), + ], + "priA_priB": [(0, 10), (20, 50), (60, 70), (100, 110), (120, 130)], + "priB_altB": [(0, 10), (20, 50), (60, 70), (80, 90), (100, 110), (120, 130)], + } + ref_tr = [{"exons": e, "id": f"reference {i+1}"} for i, e in enumerate(ref)] + transcripts = [{"exons": e, "transcript_name": n} for n, e in novel.items()] + g = Gene( + 1, + 100, + { + "chr": "chr1", + "strand": "+", + "ID": "example_coor", + "reference": {"transcripts": ref_tr}, + "transcripts": transcripts, + }, + None, + ) return g diff --git a/tests/coordination_test.py b/tests/coordination_test.py index e457a10..7536142 100644 --- a/tests/coordination_test.py +++ b/tests/coordination_test.py @@ -2,9 +2,13 @@ def test_coordination(example_gene_coor): - example_gene_coor.data['coverage'] = np.array([[300, 60, 100, 350]]) + example_gene_coor.data["coverage"] = np.array([[300, 60, 100, 350]]) res_pos = 
example_gene_coor.coordination_test(test="fisher") - assert res_pos[0][8] >= 0 and res_pos[0][9] < .05, 'Test should yield significant p-value' - example_gene_coor.data['coverage'] = np.array([[310, 380, 310, 350]]) + assert ( + res_pos[0][8] >= 0 and res_pos[0][9] < 0.05 + ), "Test should yield significant p-value" + example_gene_coor.data["coverage"] = np.array([[310, 380, 310, 350]]) res_neg = example_gene_coor.coordination_test(test="fisher") - assert res_neg[0][8] > .1 and res_neg[0][9] <= 1, 'Test should not yield significant p-value' + assert ( + res_neg[0][8] > 0.1 and res_neg[0][9] <= 1 + ), "Test should not yield significant p-value" diff --git a/tests/data/prepare_data.py b/tests/data/prepare_data.py index 9a6aec5..ef86550 100644 --- a/tests/data/prepare_data.py +++ b/tests/data/prepare_data.py @@ -15,33 +15,72 @@ def proportion(x): try: x = float(x) except ValueError as e: - raise argparse.ArgumentTypeError("%r not a floating-point literal" % (x,)) from e + raise argparse.ArgumentTypeError( + "%r not a floating-point literal" % (x,) + ) from e if x < 0.0 or x > 1.0: raise argparse.ArgumentTypeError("%r not in range [0.0, 1.0]" % (x,)) return x - parser = argparse.ArgumentParser(prog='prepare_data', description='subset genomic data for testing') - parser.add_argument('--annotation', metavar='', help='specify reference annotation [requires tabix index]') - parser.add_argument('--genome', metavar='', help='specify reference genome file [requires fai index]') - parser.add_argument('--alignment', metavar='', help='specify alignment bam file [requires bai index]') - parser.add_argument('--subsample_alignment', metavar='', help='subsample the alignment file', type=proportion) - parser.add_argument('--regions', metavar='', - default='chr2:214700000-216200000,chr8:22000000-23000000', help='regions to subset') - parser.add_argument('--out_prefix', metavar='', default='./example', help='specify output path and prefix') + parser = argparse.ArgumentParser( + 
prog="prepare_data", description="subset genomic data for testing" + ) + parser.add_argument( + "--annotation", + metavar="", + help="specify reference annotation [requires tabix index]", + ) + parser.add_argument( + "--genome", + metavar="", + help="specify reference genome file [requires fai index]", + ) + parser.add_argument( + "--alignment", + metavar="", + help="specify alignment bam file [requires bai index]", + ) + parser.add_argument( + "--subsample_alignment", + metavar="", + help="subsample the alignment file", + type=proportion, + ) + parser.add_argument( + "--regions", + metavar="", + default="chr2:214700000-216200000,chr8:22000000-23000000", + help="regions to subset", + ) + parser.add_argument( + "--out_prefix", + metavar="", + default="./example", + help="specify output path and prefix", + ) args = parser.parse_args() # parse regions - regions = [(regstr[:regstr.find(':')], )+tuple(int(i) for i in regstr[regstr.find(':')+1:].split('-')) for regstr in args.regions.split(',')] + regions = [ + (regstr[: regstr.find(":")],) + + tuple(int(i) for i in regstr[regstr.find(":") + 1 :].split("-")) + for regstr in args.regions.split(",") + ] # get subsets random.seed(42) # make sure the same reads are picked if a fraction is specified if args.annotation: - subset_annotation(args.annotation, regions, args.out_prefix+'.gff.gz') + subset_annotation(args.annotation, regions, args.out_prefix + ".gff.gz") if args.genome: - subset_genome(args.genome, regions, args.out_prefix+'.fa') + subset_genome(args.genome, regions, args.out_prefix + ".fa") if args.alignment: - subset_alignment(args.alignment, regions, args.out_prefix+'.bam', proportion=args.subsample_alignment) + subset_alignment( + args.alignment, + regions, + args.out_prefix + ".bam", + proportion=args.subsample_alignment, + ) if args.domains: - subset_domains(args.domains, regions, args.out_prefix+'_domains.csv') + subset_domains(args.domains, regions, args.out_prefix + "_domains.csv") return 0 # no error @@ 
-53,22 +92,30 @@ def subset_genome(genome_fn, regions, out_fn="example_genome.fa"): with open(out_fn, "w", encoding="utf8") as outfh: with pysam.FastaFile(genome_fn) as genome_fh: for reg in regions: - outfh.write(f'>{reg[0]}_part\n') - seq = (genome_fh.fetch(*reg)) - offset += len(reg[0])+7 - fai.append((f'{reg[0]}_part', len(seq), offset, line_length, line_length+1)) - for line in (seq[i:i+line_length] for i in range(0, len(seq), line_length)): - outfh.write(line+'\n') - offset += len(line)+1 - with open(out_fn+'.fai', "w", encoding="utf8") as outfh: + outfh.write(f">{reg[0]}_part\n") + seq = genome_fh.fetch(*reg) + offset += len(reg[0]) + 7 + fai.append( + (f"{reg[0]}_part", len(seq), offset, line_length, line_length + 1) + ) + for line in ( + seq[i : i + line_length] for i in range(0, len(seq), line_length) + ): + outfh.write(line + "\n") + offset += len(line) + 1 + with open(out_fn + ".fai", "w", encoding="utf8") as outfh: for idx, reg in zip(fai, regions): - outfh.write('\t'.join(str(v) for v in idx)+'\n') - print(f'extracted {idx[1]} bases from {reg[0]}:{reg[1]}-{reg[2]}') + outfh.write("\t".join(str(v) for v in idx) + "\n") + print(f"extracted {idx[1]} bases from {reg[0]}:{reg[1]}-{reg[2]}") def subset_alignment(bam_fn, regions, out_fn="example.bam", proportion=None): - header = {'HD': {'VN': '1.0'}, - 'SQ': [{'LN': end-start, 'SN': chrom+'_part'} for chrom, start, end in regions]} + header = { + "HD": {"VN": "1.0"}, + "SQ": [ + {"LN": end - start, "SN": chrom + "_part"} for chrom, start, end in regions + ], + } with pysam.AlignmentFile(out_fn, "wb", header=header) as out_fh: with pysam.AlignmentFile(bam_fn, "rb") as align: for new_ref_id, reg in enumerate(regions): @@ -79,7 +126,11 @@ def subset_alignment(bam_fn, regions, out_fn="example.bam", proportion=None): if read.reference_start < reg[1] or read.reference_end > reg[2]: continue if read.is_paired: - if read.next_reference_id != read.reference_id or read.next_reference_start < reg[1] or 
read.next_reference_end > reg[2]: + if ( + read.next_reference_id != read.reference_id + or read.next_reference_start < reg[1] + or read.next_reference_end > reg[2] + ): continue read.next_reference_start -= reg[1] read.next_reference_id = new_ref_id @@ -87,46 +138,52 @@ def subset_alignment(bam_fn, regions, out_fn="example.bam", proportion=None): read.reference_id = new_ref_id out_fh.write(read) n_reads += 1 - print(f'extracted {n_reads} reads from {reg[0]}:{reg[1]}-{reg[2]}') + print(f"extracted {n_reads} reads from {reg[0]}:{reg[1]}-{reg[2]}") pysam.index(out_fn) # avoid pylint complaints: https://github.com/pysam-developers/pysam/issues/819 def subset_annotation(gff_fn, regions, out_fn="example_annotation.gff.gz"): gff = pysam.TabixFile(gff_fn) - out_str = '' + out_str = "" for reg in regions: genes = set() for line in gff.fetch(*reg): ls = line.split(sep="\t") - if ls[2] == 'gene' and int(ls[3]) > reg[1] and int(ls[4]) < reg[2]: - info = dict([pair.split('=', 1) for pair in ls[8].rstrip(';').split(";")]) - genes.add(info['gene_id']) - print(f'extracted {len(genes)} genes from {reg[0]}:{reg[1]}-{reg[2]}') + if ls[2] == "gene" and int(ls[3]) > reg[1] and int(ls[4]) < reg[2]: + info = dict( + [pair.split("=", 1) for pair in ls[8].rstrip(";").split(";")] + ) + genes.add(info["gene_id"]) + print(f"extracted {len(genes)} genes from {reg[0]}:{reg[1]}-{reg[2]}") for line in gff.fetch(*reg): ls = line.split(sep="\t") - ls[0] += '_part' - ls[3] = str(int(ls[3])-reg[1]) - ls[4] = str(int(ls[4])-reg[1]) - info = dict([pair.split('=', 1) for pair in ls[8].rstrip(';').split(";")]) - if info.get('gene_id', None) in genes: - out_str += ('\t'.join(ls) + '\n') + ls[0] += "_part" + ls[3] = str(int(ls[3]) - reg[1]) + ls[4] = str(int(ls[4]) - reg[1]) + info = dict([pair.split("=", 1) for pair in ls[8].rstrip(";").split(";")]) + if info.get("gene_id", None) in genes: + out_str += "\t".join(ls) + "\n" with Bio.bgzf.BgzfWriter(out_fn, "wb") as outfh: outfh.write(out_str) - _ = 
pysam.tabix_index(out_fn, preset='gff', force=True) + _ = pysam.tabix_index(out_fn, preset="gff", force=True) def subset_domains(domain_file, regions, out_fn): - anno_df = pd.read_csv(domain_file, sep='\t').rename({'#chrom': 'chrom'}, axis=1) + anno_df = pd.read_csv(domain_file, sep="\t").rename({"#chrom": "chrom"}, axis=1) subset_df = [] for chrom, start, end in regions: - sel = anno_df.query(f'chrom=="{chrom}" and chromStart>={start} and chromEnd <= {end}').copy() + sel = anno_df.query( + f'chrom=="{chrom}" and chromStart>={start} and chromEnd <= {end}' + ).copy() sel.chromStart -= start sel.chromEnd -= start - sel.chrom += '_part' + sel.chrom += "_part" subset_df.append(sel) - pd.write_csv(pd.concat(subset_df).rename({'chrom': '#chrom'}, axis=1), out_fn, seq='\t') + pd.write_csv( + pd.concat(subset_df).rename({"chrom": "#chrom"}, axis=1), out_fn, seq="\t" + ) -if __name__ == '__main__': +if __name__ == "__main__": exit(main()) diff --git a/tests/data_import_test.py b/tests/data_import_test.py index 34ba50e..ebb70e6 100644 --- a/tests/data_import_test.py +++ b/tests/data_import_test.py @@ -3,122 +3,185 @@ from isotools.transcriptome import Transcriptome from isotools._utils import splice_identical import logging -logger = logging.getLogger('isotools') + +logger = logging.getLogger("isotools") logger.setLevel(logging.INFO) @pytest.mark.dependency() def test_import_gff(): - isoseq = Transcriptome.from_reference('tests/data/example.gff.gz') - assert len(isoseq) == 65, 'we expect 65 genes' - isoseq.save_reference('tests/data/example_ref_isotools.pkl') + isoseq = Transcriptome.from_reference("tests/data/example.gff.gz") + assert len(isoseq) == 65, "we expect 65 genes" + isoseq.save_reference("tests/data/example_ref_isotools.pkl") assert True -@pytest.mark.dependency(depends=['test_import_gff']) +@pytest.mark.dependency(depends=["test_import_gff"]) def test_import_bam(): - isoseq = Transcriptome.from_reference('tests/data/example_ref_isotools.pkl') - assert 
isoseq.n_transcripts == 0, 'there should not be any transcripts' - for sample in ('CTL', 'VPA'): - isoseq.add_sample_from_bam(f'tests/data/example_1_{sample}.bam', sample_name=sample, group=sample, platform='SequelII') + isoseq = Transcriptome.from_reference("tests/data/example_ref_isotools.pkl") + assert isoseq.n_transcripts == 0, "there should not be any transcripts" + for sample in ("CTL", "VPA"): + isoseq.add_sample_from_bam( + f"tests/data/example_1_{sample}.bam", + sample_name=sample, + group=sample, + platform="SequelII", + ) # assert isoseq.n_transcripts == 185, 'we expect 185 transcripts' - isoseq.add_qc_metrics('tests/data/example.fa') - isoseq.add_orf_prediction('tests/data/example.fa') - isoseq.save('tests/data/example_1_isotools.pkl') + isoseq.add_qc_metrics("tests/data/example.fa") + isoseq.add_orf_prediction("tests/data/example.fa") + isoseq.save("tests/data/example_1_isotools.pkl") -@pytest.mark.dependency(depends=['test_import_bam']) +@pytest.mark.dependency(depends=["test_import_bam"]) def test_fsm(): - isoseq = Transcriptome.load('tests/data/example_1_isotools.pkl') + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") count = 0 - for gene, _, transcript in isoseq.iter_transcripts(query='FSM'): - assert transcript['annotation'][0] == 0 + for gene, _, transcript in isoseq.iter_transcripts(query="FSM"): + assert transcript["annotation"][0] == 0 count += 1 - for ref_id in transcript['annotation'][1]['FSM']: - assert splice_identical(transcript['exons'], gene.ref_transcripts[ref_id]['exons']) - assert count == 22, 'expected 22 FSM transcripts' + for ref_id in transcript["annotation"][1]["FSM"]: + assert splice_identical( + transcript["exons"], gene.ref_transcripts[ref_id]["exons"] + ) + assert count == 22, "expected 22 FSM transcripts" -@pytest.mark.dependency(depends=['test_import_bam']) +@pytest.mark.dependency(depends=["test_import_bam"]) def test_import_csv_reconstruct(): # reconstruct gene structure from scratch - isoseq = 
Transcriptome.load('tests/data/example_1_isotools.pkl') + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") cov_tab = isoseq.transcript_table(coverage=True) - cov_tab.to_csv('tests/data/example_1_cov.csv') - isoseq.write_gtf('tests/data/example_1.gtf') - isoseq_csv = Transcriptome.from_reference('tests/data/example_ref_isotools.pkl') - isoseq_csv._add_novel_gene('nix', 10, 20, '-', {'exons': [10, 20]}) # additional gene should not confuse/break the function - id_map = isoseq_csv.add_sample_from_csv('tests/data/example_1_cov.csv', 'tests/data/example_1.gtf', - reconstruct_genes=True, sample_properties=isoseq.sample_table, sep=',') - remapped_genes = {gid: gid2 for gid2, id_dict in id_map.items() for gid in id_dict.values()} - logger.info('remapped %s transcripts', sum(len(d) for d in id_map)) - assert set(isoseq.samples) == set(isoseq_csv.samples), 'discrepant samples after csv import' - stab1, stab2 = isoseq.sample_table.set_index('name'), isoseq_csv.sample_table.set_index('name') + cov_tab.to_csv("tests/data/example_1_cov.csv") + isoseq.write_gtf("tests/data/example_1.gtf") + isoseq_csv = Transcriptome.from_reference("tests/data/example_ref_isotools.pkl") + isoseq_csv._add_novel_gene( + "nix", 10, 20, "-", {"exons": [10, 20]} + ) # additional gene should not confuse/break the function + id_map = isoseq_csv.add_sample_from_csv( + "tests/data/example_1_cov.csv", + "tests/data/example_1.gtf", + reconstruct_genes=True, + sample_properties=isoseq.sample_table, + sep=",", + ) + remapped_genes = { + gid: gid2 for gid2, id_dict in id_map.items() for gid in id_dict.values() + } + logger.info("remapped %s transcripts", sum(len(d) for d in id_map)) + assert set(isoseq.samples) == set( + isoseq_csv.samples + ), "discrepant samples after csv import" + stab1, stab2 = isoseq.sample_table.set_index( + "name" + ), isoseq_csv.sample_table.set_index("name") for sample in isoseq.samples: - assert stab1.loc[sample, 'group'] == stab2.loc[sample, 'group'], 'wrong group 
after csv import for sample %s' % sample - assert stab1.loc[sample, 'nonchimeric_reads'] == stab2.loc[sample, 'nonchimeric_reads'], 'wrong number of reads after csv import for sample %s' % sample + assert stab1.loc[sample, "group"] == stab2.loc[sample, "group"], ( + "wrong group after csv import for sample %s" % sample + ) + assert ( + stab1.loc[sample, "nonchimeric_reads"] + == stab2.loc[sample, "nonchimeric_reads"] + ), ("wrong number of reads after csv import for sample %s" % sample) discrepancy = False - for gene in isoseq.iter_genes(query='EXPRESSED'): - if (gene.is_annotated and gene.id in remapped_genes) or (gene.id not in isoseq_csv and gene.id not in remapped_genes): - logger.error('gene missing/renamed after csv import: %s' % str(gene)) + for gene in isoseq.iter_genes(query="EXPRESSED"): + if (gene.is_annotated and gene.id in remapped_genes) or ( + gene.id not in isoseq_csv and gene.id not in remapped_genes + ): + logger.error("gene missing/renamed after csv import: %s" % str(gene)) discrepancy = True - for gene_csv in isoseq_csv.iter_genes(query='EXPRESSED'): + for gene_csv in isoseq_csv.iter_genes(query="EXPRESSED"): if not gene_csv.is_annotated and gene_csv.id in remapped_genes: gene_id = remapped_genes[gene_csv.id] else: gene_id = gene_csv.id gene = isoseq[gene_id] if len(gene.transcripts) != len(gene_csv.transcripts): - logger.error('number of transcripts for %s changed after csv import: %s != %s', gene.id, len(gene.transcripts), len(gene_csv.transcripts)) + logger.error( + "number of transcripts for %s changed after csv import: %s != %s", + gene.id, + len(gene.transcripts), + len(gene_csv.transcripts), + ) discrepancy = True - assert not discrepancy, 'discrepancy found after csv import' + assert not discrepancy, "discrepancy found after csv import" -@pytest.mark.dependency(depends=['test_import_bam']) +@pytest.mark.dependency(depends=["test_import_bam"]) def test_import_csv(): # use gene structure from gtf - isoseq = 
Transcriptome.load('tests/data/example_1_isotools.pkl') + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") cov_tab = isoseq.transcript_table(coverage=True) - cov_tab.to_csv('tests/data/example_1_cov.csv') - isoseq.write_gtf('tests/data/example_1.gtf') - isoseq_csv = Transcriptome.from_reference('tests/data/example_ref_isotools.pkl') - isoseq_csv._add_novel_gene('nix', 10, 20, '-', {'exons': [10, 20]}) # make it a little harder - id_map = isoseq_csv.add_sample_from_csv('tests/data/example_1_cov.csv', 'tests/data/example_1.gtf', - reconstruct_genes=True, sample_properties=isoseq.sample_table, sep=',') + cov_tab.to_csv("tests/data/example_1_cov.csv") + isoseq.write_gtf("tests/data/example_1.gtf") + isoseq_csv = Transcriptome.from_reference("tests/data/example_ref_isotools.pkl") + isoseq_csv._add_novel_gene( + "nix", 10, 20, "-", {"exons": [10, 20]} + ) # make it a little harder + id_map = isoseq_csv.add_sample_from_csv( + "tests/data/example_1_cov.csv", + "tests/data/example_1.gtf", + reconstruct_genes=True, + sample_properties=isoseq.sample_table, + sep=",", + ) remapped_genes = {gid: k for k, v in id_map.items() for gid in v.values()} - logger.info('remapped %s genes', len(id_map)) - assert set(isoseq.samples) == set(isoseq_csv.samples), 'discrepant samples after csv import' - stab1, stab2 = isoseq.sample_table.set_index('name'), isoseq_csv.sample_table.set_index('name') + logger.info("remapped %s genes", len(id_map)) + assert set(isoseq.samples) == set( + isoseq_csv.samples + ), "discrepant samples after csv import" + stab1, stab2 = isoseq.sample_table.set_index( + "name" + ), isoseq_csv.sample_table.set_index("name") for sample in isoseq.samples: - assert stab1.loc[sample, 'group'] == stab2.loc[sample, 'group'], 'wrong group after csv import for sample %s' % sample - assert stab1.loc[sample, 'nonchimeric_reads'] == stab2.loc[sample, 'nonchimeric_reads'], 'wrong number of reads after csv import for sample %s' % sample + assert stab1.loc[sample, 
"group"] == stab2.loc[sample, "group"], ( + "wrong group after csv import for sample %s" % sample + ) + assert ( + stab1.loc[sample, "nonchimeric_reads"] + == stab2.loc[sample, "nonchimeric_reads"] + ), ("wrong number of reads after csv import for sample %s" % sample) discrepancy = False - for gene in isoseq.iter_genes(query='EXPRESSED'): - if (gene.is_annotated and gene.id in remapped_genes) or (gene.id not in isoseq_csv and gene.id not in remapped_genes): - logger.error('gene missing/renamed after csv import: %s' % str(gene)) + for gene in isoseq.iter_genes(query="EXPRESSED"): + if (gene.is_annotated and gene.id in remapped_genes) or ( + gene.id not in isoseq_csv and gene.id not in remapped_genes + ): + logger.error("gene missing/renamed after csv import: %s" % str(gene)) discrepancy = True - for gene_csv in isoseq_csv.iter_genes(query='EXPRESSED'): + for gene_csv in isoseq_csv.iter_genes(query="EXPRESSED"): if not gene_csv.is_annotated and gene_csv.id in remapped_genes: gene_id = remapped_genes[gene_csv.id] else: gene_id = gene_csv.id gene = isoseq[gene_id] if len(gene.transcripts) != len(gene_csv.transcripts): - logger.error('number of transcripts for %s changed after csv import: %s != %s', gene.id, len(gene.transcripts), len(gene_csv.transcripts)) + logger.error( + "number of transcripts for %s changed after csv import: %s != %s", + gene.id, + len(gene.transcripts), + len(gene_csv.transcripts), + ) discrepancy = True - assert not discrepancy, 'discrepancy found after csv import' + assert not discrepancy, "discrepancy found after csv import" -@pytest.mark.dependency(depends=['test_import_gff']) +@pytest.mark.dependency(depends=["test_import_gff"]) def test_orf(): - total, same = {'+': 0, '-': 0}, {'+': 0, '-': 0} - isoseq = Transcriptome.from_reference('tests/data/example_ref_isotools.pkl') - with FastaFile('tests/data/example.fa') as genome_fh: + total, same = {"+": 0, "-": 0}, {"+": 0, "-": 0} + isoseq = 
Transcriptome.from_reference("tests/data/example_ref_isotools.pkl") + with FastaFile("tests/data/example.fa") as genome_fh: for gene in isoseq: - gene.add_orfs(genome_fh=genome_fh, reference=True) + gene.add_orfs(genome_fh=genome_fh, reference=True) for transcript in gene.ref_transcripts: - if transcript['transcript_type'] == 'protein_coding' and 'CDS' in transcript: + if ( + transcript["transcript_type"] == "protein_coding" + and "CDS" in transcript + ): total[gene.strand] += 1 - if transcript['CDS'] == transcript['ORF'][:2]: + if transcript["CDS"] == transcript["ORF"][:2]: same[gene.strand] += 1 - assert same["+"]/total["+"] > .9, "at least 90% protein coding transcripts CDS on + should match longest ORF." - assert same["-"]/total["-"] > .9, "at least 90% protein coding transcripts CDS on - should match longest ORF." + assert ( + same["+"] / total["+"] > 0.9 + ), "at least 90% protein coding transcripts CDS on + should match longest ORF." + assert ( + same["-"] / total["-"] > 0.9 + ), "at least 90% protein coding transcripts CDS on - should match longest ORF." diff --git a/tests/diffsplice_test.py b/tests/diffsplice_test.py index bdf8bae..069a348 100644 --- a/tests/diffsplice_test.py +++ b/tests/diffsplice_test.py @@ -2,11 +2,14 @@ def test_diffsplice(): - isoseq = Transcriptome.load('tests/data/example_1_isotools.pkl') - res = isoseq.altsplice_test(groups=isoseq.groups()).sort_values('pvalue') - assert sum(res.padj < .1) == 1, 'expected exactly one significant event' + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") + res = isoseq.altsplice_test(groups=isoseq.groups()).sort_values("pvalue") + assert sum(res.padj < 0.1) == 1, "expected exactly one significant event" best = res.iloc[0] - assert best.gene == 'SLC39A14', 'Best differential splicing should be in SLC39A14' - assert (best.start, best.end) == (408496, 414779), 'Genomic coordinates do not match expectations.' - assert best.splice_type == 'ME', 'Splice type does not match expectations.' 
- assert best.padj < .001, 'Event should be more significant.' + assert best.gene == "SLC39A14", "Best differential splicing should be in SLC39A14" + assert (best.start, best.end) == ( + 408496, + 414779, + ), "Genomic coordinates do not match expectations." + assert best.splice_type == "ME", "Splice type does not match expectations." + assert best.padj < 0.001, "Event should be more significant." diff --git a/tests/domain_test.py b/tests/domain_test.py index a641781..5359213 100644 --- a/tests/domain_test.py +++ b/tests/domain_test.py @@ -3,12 +3,20 @@ def test_anno_domains(): - isoseq = Transcriptome.load('tests/data/example_1_isotools.pkl') - isoseq.add_annotation_domains('tests/data/example_anno_domains.csv', category='domains', progress_bar=False) - gene = isoseq['FN1'] - ref_tr = next(transcript for transcript in gene.ref_transcripts if transcript['transcript_name'] == 'FN1-207') - dom = ref_tr['domain']['annotation'] - assert len(dom) == 25, 'expected 25 annotation domains in FN1-207, but found {len(dom)}' + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") + isoseq.add_annotation_domains( + "tests/data/example_anno_domains.csv", category="domains", progress_bar=False + ) + gene = isoseq["FN1"] + ref_tr = next( + transcript + for transcript in gene.ref_transcripts + if transcript["transcript_name"] == "FN1-207" + ) + dom = ref_tr["domain"]["annotation"] + assert ( + len(dom) == 25 + ), "expected 25 annotation domains in FN1-207, but found {len(dom)}" - diffexpr = isoseq.altsplice_test(groups=isoseq.groups()).sort_values('pvalue') - diffexpr = add_domains_to_table(diffexpr, isoseq, 'annotation', insert_after='nmdB') + diffexpr = isoseq.altsplice_test(groups=isoseq.groups()).sort_values("pvalue") + diffexpr = add_domains_to_table(diffexpr, isoseq, "annotation", insert_after="nmdB") diff --git a/tests/splice_graph_test.py b/tests/splice_graph_test.py index 75390e3..54a11a7 100644 --- a/tests/splice_graph_test.py +++ b/tests/splice_graph_test.py 
@@ -1,36 +1,48 @@ import pytest from isotools.transcriptome import Transcriptome -from isotools._utils import _find_splice_sites, _get_overlap, _get_exonic_region, pairwise +from isotools._utils import ( + _find_splice_sites, + _get_overlap, + _get_exonic_region, + pairwise, +) # @pytest.mark.dependency(depends=['test_import_bam']) def test_import_find_splice_site(): - isoseq = Transcriptome.load('tests/data/example_1_isotools.pkl') - for gene, _, transcript in isoseq.iter_transcripts(query='not NOVEL_GENE'): - sj = [(exon1[1], exon2[0]) for exon1, exon2 in pairwise(transcript['exons'])] + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") + for gene, _, transcript in isoseq.iter_transcripts(query="not NOVEL_GENE"): + sj = [(exon1[1], exon2[0]) for exon1, exon2 in pairwise(transcript["exons"])] c1 = gene.ref_segment_graph.find_splice_sites(sj) c2 = _find_splice_sites(sj, gene.ref_transcripts) - assert all(c1 == c2), 'isotools._transcriptome_io._find_splice_sites and Segment_Graph.find_splice_sites yield different results' + assert all( + c1 == c2 + ), "isotools._transcriptome_io._find_splice_sites and Segment_Graph.find_splice_sites yield different results" @pytest.mark.dependency() def test_exon_regions(): - isoseq = Transcriptome.load('tests/data/example_1_isotools.pkl') - for gene in isoseq.iter_genes(query='not NOVEL_GENE'): + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") + for gene in isoseq.iter_genes(query="not NOVEL_GENE"): c1 = gene.ref_segment_graph.get_exonic_region() c2 = _get_exonic_region(gene.ref_transcripts) - assert len(c1) == len(c2), 'isotools._transcriptome_io._get_exonic_region and Segment_Graph.get_exonic_region yield different length' - assert all(reg1[0] == reg2[0] and reg1[1] == reg2[1] for reg1, reg2 in zip(c1, c2)), \ - 'isotools._transcriptome_io._get_exonic_region and Segment_Graph.get_exonic_region yield different regions' + assert len(c1) == len( + c2 + ), "isotools._transcriptome_io._get_exonic_region 
and Segment_Graph.get_exonic_region yield different length" + assert all( + reg1[0] == reg2[0] and reg1[1] == reg2[1] for reg1, reg2 in zip(c1, c2) + ), "isotools._transcriptome_io._get_exonic_region and Segment_Graph.get_exonic_region yield different regions" assert True -@pytest.mark.dependency(depends=['test_exon_regions']) +@pytest.mark.dependency(depends=["test_exon_regions"]) def test_import_exonic_overlap(): - isoseq = Transcriptome.load('tests/data/example_1_isotools.pkl') - for gene, _, transcript in isoseq.iter_transcripts(query='not NOVEL_GENE'): - c1 = gene.ref_segment_graph.get_overlap(transcript['exons'])[0] - c2 = _get_overlap(transcript['exons'], gene.ref_transcripts) - assert c1 == c2, 'isotools._transcriptome_io._get_overlap and Segment_Graph.get_overlap yield different results' + isoseq = Transcriptome.load("tests/data/example_1_isotools.pkl") + for gene, _, transcript in isoseq.iter_transcripts(query="not NOVEL_GENE"): + c1 = gene.ref_segment_graph.get_overlap(transcript["exons"])[0] + c2 = _get_overlap(transcript["exons"], gene.ref_transcripts) + assert ( + c1 == c2 + ), "isotools._transcriptome_io._get_overlap and Segment_Graph.get_overlap yield different results" assert True diff --git a/tox.ini b/tox.ini index 4d75d51..06f0222 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,13 @@ [tox] -minversion = 3.8.0 -envlist = py36, py37, py38, py39, py310, flake8 +minversion = 3.10.0 +envlist = py310, py311, py312, flake8 isolated_build = true [gh-actions] python = - 3.7: py37 - 3.8: py38 - 3.9: py39 - 3.10: py310, flake8 + 3.10: py310 + 3.11: py311 + 3.12: py312, flake8 [testenv] setenv = @@ -19,6 +18,11 @@ commands = pytest --basetemp={envtmpdir} [testenv:flake8] -basepython = python3.10 -deps = flake8 -commands = flake8 src tests +basepython = python3.12 +deps = + flake8 + flake8-bugbear + black +commands = + black src tests + flake8 --config=setup.cfg src tests \ No newline at end of file