Anopheles refactor part 5 - AnophelesSnpData (#393)
* start refactoring snp data functions

* fix typing

* comments

* fix logic error

* comments

* site_filters

* snp_sites

* fix errors

* snp_genotypes, is_accessible

* snp_variants

* site_annotations, snp_calls

* snp_allele_counts; plot_snps

* fix typing

* fix regression

* fix bug

* wip sample_indices

* parameter validation

* parameter validation

* add typeguard to dependencies

* use typeguard, simplify types

* fix typing bug

* fix bugs

* add more typechecked annotations; tighten up region parsing

* fix typing bug

* further typing improvements

* fix typing bug

* more type hints

* fix typing

* fix typing

* tweaks

* fix typing

* fix typing

* fix bokeh type

* check notebooks

* disable typeguard because leaking memory

* fix test failures

* strip typeguard annotations

* fix typing errors

* fix typing error

* squash bugs

* add typeguard to fast tests on ci

* squash bugs

* fix snafu

* ignore notebooks output

* home-rolled type check decorator (a rough sketch of the idea follows this list)

* improve message

* depend on typeguard

* wip simulate genotypes

* wip simulate genotypes

* wip simulate genotypes

* wip test_snp_data

* wip test_snp_data

* wip test_snp_data - open_snp_sites

* wip test_snp_data

* wip refactor tests

* squashed commits

* update poetry

* deal with runs of Ns better in plot_snps
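Several of the commits above concern runtime type checking of user-supplied parameters, first via typeguard and then via a home-rolled decorator after typeguard was found to leak memory. The actual decorator is not shown on this page; purely as a hedged illustration of the general idea, a minimal runtime type-check decorator might look like the following (names and behaviour are assumptions, not the malariagen_data implementation):

# Hypothetical sketch only: a minimal runtime type-check decorator in the
# spirit of the "home-rolled type check decorator" commit above. It handles
# plain classes (str, int, ...) but not generics such as List[Region].
import functools
import inspect
import typing


def check_types(f):
    hints = typing.get_type_hints(f)
    sig = inspect.signature(f)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = hints.get(name)
            # Only check simple, non-generic annotations.
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(
                    f"{f.__name__}: parameter {name!r} must be "
                    f"{expected.__name__}, got {type(value).__name__}"
                )
        return f(*args, **kwargs)

    return wrapper


@check_types
def resolve(contig: str, size: int) -> str:
    return f"{contig}:{size}"


resolve("2L", 1000)  # ok
# resolve("2L", "1000") would raise TypeError with an informative message.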
alimanfoo authored May 18, 2023
1 parent 61bb489 commit 77cd67b
Showing 24 changed files with 4,789 additions and 3,920 deletions.
.github/workflows/tests.yml: 5 changes (3 additions & 2 deletions)
@@ -29,9 +29,10 @@ jobs:
 - name: Install dependencies
 run: poetry install

-# Run a subset of tests first which run quickly to fail fast.
+# Run a subset of tests first which run quickly without accessing
+# any remote data in order to fail fast where possible.
 - name: Run fast unit tests
-run: poetry run pytest -v tests/anoph
+run: poetry run pytest -v tests/anoph --typeguard-packages=malariagen_data,malariagen_data.anoph

 - name: Restore GCS cache
 uses: actions/cache/restore@v3
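The --typeguard-packages option comes from typeguard's pytest plugin: it installs an import hook for the named packages so that annotated functions are type-checked at call time while the fast tests run. As a rough, hedged equivalent outside pytest (the exact module path depends on the typeguard version installed):

# Illustrative only: install typeguard's import hook before importing the
# package, so annotated functions in it are checked at runtime.
from typeguard import install_import_hook  # typeguard >= 3; in 2.x the hook lives in typeguard.importhook

install_import_hook("malariagen_data")
import malariagen_data  # noqa: E402  (imported after the hook on purpose)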
.gitignore: 3 changes (2 additions & 1 deletion)
@@ -14,4 +14,5 @@ tests/anoph/fixture/simulated
*~

# nbconvert outputs
-*.nbconvert.ipynb
+notebooks/*.nbconvert.ipynb
+notebooks/*.html
malariagen_data/af1.py: 4 changes (1 addition & 3 deletions)
@@ -18,7 +18,6 @@
G123_CALIBRATION_CACHE_NAME = "af1_g123_calibration_v1"
H1X_GWSS_CACHE_NAME = "af1_h1x_gwss_v1"
IHS_GWSS_CACHE_NAME = "af1_ihs_gwss_v1"
-DEFAULT_SITE_MASK = "funestus"


class Af1(AnophelesDataResource):
@@ -83,8 +82,6 @@ class Af1(AnophelesDataResource):
_g123_calibration_cache_name = G123_CALIBRATION_CACHE_NAME
_h1x_gwss_cache_name = H1X_GWSS_CACHE_NAME
_ihs_gwss_cache_name = IHS_GWSS_CACHE_NAME
-site_mask_ids = ("funestus",)
-_default_site_mask = DEFAULT_SITE_MASK
phasing_analysis_ids = ("funestus",)
_default_phasing_analysis = "funestus"

@@ -109,6 +106,7 @@ def __init__(
aim_analysis=None,
aim_metadata_dtype=None,
site_filters_analysis=site_filters_analysis,
+default_site_mask="funestus",
bokeh_output_notebook=bokeh_output_notebook,
results_cache=results_cache,
log=log,
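In both Af1 (above) and Ag3 (below), the site_mask_ids and _default_site_mask class attributes disappear and a default_site_mask argument is passed to the AnophelesDataResource constructor instead, presumably to be consumed by the new AnophelesSnpData layer. A hypothetical, simplified sketch of that configuration-via-constructor pattern (not the actual class definitions):

# Hypothetical names, for illustration of the pattern only.
from typing import Optional


class SnpDataResource:
    def __init__(self, *, default_site_mask: Optional[str] = None):
        self._default_site_mask = default_site_mask

    def _prep_site_mask_param(self, site_mask: Optional[str]) -> Optional[str]:
        # Fall back to whatever default the concrete subclass configured.
        return site_mask if site_mask is not None else self._default_site_mask


class FunestusResource(SnpDataResource):
    def __init__(self):
        # Analogous to Af1 passing default_site_mask="funestus" to super().
        super().__init__(default_site_mask="funestus")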
malariagen_data/ag3.py: 78 changes (39 additions & 39 deletions)
@@ -2,6 +2,7 @@
import warnings
from bisect import bisect_left, bisect_right
from textwrap import dedent
+from typing import List

import dask
import dask.array as da
@@ -13,13 +14,17 @@

import malariagen_data # used for .__version__

+from .anoph.base import base_params
from .anopheles import AnophelesDataResource, gplt_params
from .util import (
DIM_SAMPLE,
DIM_VARIANT,
Region,
da_from_zarr,
init_zarr_store,
+parse_multi_region,
+parse_single_region,
+region_str,
simple_xarray_concat,
)

@@ -43,7 +48,6 @@
G123_GWSS_CACHE_NAME = "ag3_g123_gwss_v1"
H1X_GWSS_CACHE_NAME = "ag3_h1x_gwss_v1"
IHS_GWSS_CACHE_NAME = "ag3_ihs_gwss_v1"
-DEFAULT_SITE_MASK = "gamb_colu_arab"


class Ag3(AnophelesDataResource):
@@ -113,8 +117,6 @@ class Ag3(AnophelesDataResource):
_g123_calibration_cache_name = G123_CALIBRATION_CACHE_NAME
_h1x_gwss_cache_name = H1X_GWSS_CACHE_NAME
_ihs_gwss_cache_name = IHS_GWSS_CACHE_NAME
-site_mask_ids = ("gamb_colu_arab", "gamb_colu", "arab")
-_default_site_mask = DEFAULT_SITE_MASK
phasing_analysis_ids = ("gamb_colu_arab", "gamb_colu", "arab")
_default_phasing_analysis = "gamb_colu_arab"

@@ -147,6 +149,7 @@ def __init__(
"aim_species": object,
},
site_filters_analysis=site_filters_analysis,
+default_site_mask="gamb_colu_arab",
bokeh_output_notebook=bokeh_output_notebook,
results_cache=results_cache,
log=log,
@@ -675,7 +678,7 @@ def _cnv_hmm_dataset(self, *, contig, sample_set, inline_array, chunks):

def cnv_hmm(
self,
-region,
+region: base_params.region,
sample_sets=None,
sample_query=None,
max_coverage_variance=DEFAULT_MAX_COVERAGE_VARIANCE,
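Throughout this file, the bare region parameters gain annotations such as base_params.region and base_params.single_region from the new malariagen_data.anoph.base module. The exact definitions are not shown in this diff; a hedged guess at their shape, consistent with how they are used here, would be something like:

# Assumed shape of the base_params aliases, for illustration only; the real
# definitions in malariagen_data/anoph/base.py may differ.
from typing import Sequence, Union

from malariagen_data.util import Region

# A single region: a contig ("2L"), a span ("2L:1,000,000-2,000,000"),
# a gene ID, or an already-parsed Region object.
single_region = Union[str, Region]
# Functions that accept several regions also take a sequence of the above.
region = Union[single_region, Sequence[single_region]]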
@@ -717,13 +720,12 @@ def cnv_hmm(

debug("normalise parameters")
sample_sets = self._prep_sample_sets_param(sample_sets=sample_sets)
-region = self.resolve_region(region)
-if isinstance(region, Region):
-region = [region]
+regions: List[Region] = parse_multi_region(self, region)
+del region

debug("access CNV HMM data and concatenate as needed")
lx = []
-for r in region:
+for r in regions:
ly = []
for s in sample_sets:
y = self._cnv_hmm_dataset(
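The recurring change in this file replaces the old resolve_region plus isinstance dance with parse_multi_region from malariagen_data.util, which always returns a list of Region objects; del region then prevents the raw parameter from being used by accident. A small hedged sketch of the calling pattern (the summarising helper is hypothetical, standing in for builders like _cnv_hmm_dataset):

from typing import Any, List

from malariagen_data.util import Region, parse_multi_region


def _summarise_one(r: Region) -> str:
    # Hypothetical stand-in for per-region work such as _cnv_hmm_dataset().
    return f"{r.contig}:{r.start}-{r.end}"


def cnv_summary(resource: Any, region) -> List[str]:
    # Normalise once: a single region spec or a sequence of them becomes
    # a concrete List[Region].
    regions: List[Region] = parse_multi_region(resource, region)
    del region  # guard against accidentally reusing the raw parameter
    return [_summarise_one(r) for r in regions]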
@@ -898,7 +900,7 @@ def _cnv_coverage_calls_dataset(

def cnv_coverage_calls(
self,
-region,
+region: base_params.region,
sample_set,
analysis,
inline_array=True,
@@ -937,13 +939,12 @@ def cnv_coverage_calls(
# calling is done independently in different sample sets.

debug("normalise parameters")
-region = self.resolve_region(region)
-if isinstance(region, Region):
-region = [region]
+regions: List[Region] = parse_multi_region(self, region)
+del region

debug("access data and concatenate as needed")
lx = []
-for r in region:
+for r in regions:
debug("obtain coverage calls for the contig")
x = self._cnv_coverage_calls_dataset(
contig=r.contig,
@@ -1131,7 +1132,7 @@ def cnv_discordant_read_calls(

def gene_cnv(
self,
-region,
+region: base_params.region,
sample_sets=None,
sample_query=None,
max_coverage_variance=DEFAULT_MAX_COVERAGE_VARIANCE,
@@ -1163,9 +1164,8 @@ def gene_cnv(
"""

-region = self.resolve_region(region)
-if isinstance(region, Region):
-region = [region]
+regions: List[Region] = parse_multi_region(self, region)
+del region

ds = simple_xarray_concat(
[
@@ -1175,7 +1175,7 @@
sample_query=sample_query,
max_coverage_variance=max_coverage_variance,
)
-for r in region
+for r in regions
],
dim="genes",
)
@@ -1267,7 +1267,7 @@ def _gene_cnv(self, *, region, sample_sets, sample_query, max_coverage_variance)

def gene_cnv_frequencies(
self,
-region,
+region: base_params.region,
cohorts,
sample_query=None,
min_cohort_size=10,
@@ -1318,9 +1318,8 @@

debug("check and normalise parameters")
self._check_param_min_cohort_size(min_cohort_size)
-region = self.resolve_region(region)
-if isinstance(region, Region):
-region = [region]
+regions: List[Region] = parse_multi_region(self, region)
+del region

debug("access and concatenate data from regions")
df = pd.concat(
@@ -1334,13 +1333,13 @@
drop_invariant=drop_invariant,
max_coverage_variance=max_coverage_variance,
)
-for r in region
+for r in regions
],
axis=0,
)

debug("add metadata")
title = f"Gene CNV frequencies ({self._region_str(region)})"
title = f"Gene CNV frequencies ({region_str(regions)})"
df.attrs["title"] = title

return df
@@ -1490,7 +1489,7 @@ def _gene_cnv_frequencies(

def gene_cnv_frequencies_advanced(
self,
-region,
+region: base_params.region,
area_by,
period_by,
sample_sets=None,
@@ -1553,9 +1552,8 @@

self._check_param_min_cohort_size(min_cohort_size)

-region = self.resolve_region(region)
-if isinstance(region, Region):
-region = [region]
+regions: List[Region] = parse_multi_region(self, region)
+del region

ds = simple_xarray_concat(
[
@@ -1571,12 +1569,12 @@
max_coverage_variance=max_coverage_variance,
ci_method=ci_method,
)
-for r in region
+for r in regions
],
dim="variants",
)

title = f"Gene CNV frequencies ({self._region_str(region)})"
title = f"Gene CNV frequencies ({region_str(regions)})"
ds.attrs["title"] = title

return ds
@@ -1740,7 +1738,7 @@ def _gene_cnv_frequencies_advanced(
def plot_cnv_hmm_coverage_track(
self,
sample,
-region,
+region: base_params.single_region,
sample_set=None,
y_max="auto",
sizing_mode=gplt_params.sizing_mode_default,
@@ -1777,7 +1775,7 @@
Passed through to bokeh line() function.
show : bool, optional
If true, show the plot.
-x_range : bokeh.models.Range1d, optional
+x_range : bokeh.models.Range, optional
X axis range (for linking to other tracks).
Returns
@@ -1792,7 +1790,8 @@
import bokeh.plotting as bkplt

debug("resolve region")
-region = self.resolve_region(region)
+region_prepped: Region = parse_single_region(self, region)
+del region

debug("access sample metadata, look up sample")
sample_rec = self._lookup_sample(sample=sample, sample_set=sample_set)
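For plotting functions that take exactly one region, the commit uses the parse_single_region counterpart and renames the local variable to region_prepped, again deleting the raw parameter. A minimal hedged sketch of the same idiom (describe_region is made up for illustration):

from malariagen_data.util import Region, parse_single_region


def describe_region(resource, region) -> str:
    # Validate and normalise a single region spec into one Region object.
    region_prepped: Region = parse_single_region(resource, region)
    del region  # only the validated Region is used from here on
    return f"contig={region_prepped.contig}"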
@@ -1801,7 +1800,7 @@

debug("access HMM data")
hmm = self.cnv_hmm(
-region=region, sample_sets=sample_set, max_coverage_variance=None
+region=region_prepped, sample_sets=sample_set, max_coverage_variance=None
)

debug("select data for the given sample")
@@ -1863,7 +1862,7 @@
debug("tidy up the plot")
fig.yaxis.axis_label = "Copy number"
fig.yaxis.ticker = list(range(y_max + 1))
-self._bokeh_style_genome_xaxis(fig, region.contig)
+self._bokeh_style_genome_xaxis(fig, region_prepped.contig)
fig.add_layout(fig.legend[0], "right")

if show:
@@ -1967,7 +1966,7 @@ def plot_cnv_hmm_coverage(

def plot_cnv_hmm_heatmap_track(
self,
-region,
+region: base_params.single_region,
sample_sets=None,
sample_query=None,
max_coverage_variance=DEFAULT_MAX_COVERAGE_VARIANCE,
@@ -2017,11 +2016,12 @@
import bokeh.palettes as bkpal
import bokeh.plotting as bkplt

-region = self.resolve_region(region)
+region_prepped: Region = parse_single_region(self, region)
+del region

debug("access HMM data")
ds_cnv = self.cnv_hmm(
-region=region,
+region=region_prepped,
sample_sets=sample_sets,
sample_query=sample_query,
max_coverage_variance=max_coverage_variance,
@@ -2105,7 +2105,7 @@ def plot_cnv_hmm_heatmap_track(

debug("tidy")
fig.yaxis.axis_label = "Samples"
-self._bokeh_style_genome_xaxis(fig, region.contig)
+self._bokeh_style_genome_xaxis(fig, region_prepped.contig)
fig.yaxis.ticker = bkmod.FixedTicker(
ticks=np.arange(len(sample_id)),
)
