diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 08321701c..a014e65eb 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -7,4 +7,6 @@ jobs:
     steps:
     - uses: actions/checkout@v2
     - uses: actions/setup-python@v2
-    - uses: pre-commit/action@v2.0.0
+      with:
+        python-version: 3.7
+    - uses: pre-commit/action@v2.0.3
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f8e63f053..6544d8e7f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
     hooks:
       - id: black
   - repo: https://github.com/keewis/blackdoc
-    rev: v0.3.3
+    rev: v0.3.4
     hooks:
       - id: blackdoc
   - repo: https://gitlab.com/pycqa/flake8
diff --git a/README.md b/README.md
index f742b591b..a09bc5465 100644
--- a/README.md
+++ b/README.md
@@ -72,6 +72,12 @@ $ poetry publish
 
 ## Release notes
 
+### 0.9.0
+
+* Add `Ag3.haplotypes()` and supporting functions `Ag3.open_haplotypes()`
+  and `Ag3.open_haplotype_sites()`.
+
+
 ### 0.8.0
 
 * Add site filter columns to dataframes returned by
diff --git a/malariagen_data/ag3.py b/malariagen_data/ag3.py
index 3d97077ed..739dfad55 100644
--- a/malariagen_data/ag3.py
+++ b/malariagen_data/ag3.py
@@ -96,6 +96,8 @@ def __init__(self, url, **kwargs):
         self._cache_cnv_hmm = dict()
         self._cache_cnv_coverage_calls = dict()
         self._cache_cnv_discordant_read_calls = dict()
+        self._cache_haplotypes = dict()
+        self._cache_haplotype_sites = dict()
 
     @property
     def releases(self):
@@ -347,7 +349,7 @@ def site_filters(
         field="filter_pass",
         analysis="dt_20200416",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access SNP site filters.
 
@@ -400,7 +402,7 @@ def snp_sites(
         site_mask=None,
         site_filters="dt_20200416",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access SNP site data (positions and alleles).
 
@@ -480,7 +482,7 @@ def snp_genotypes(
         site_mask=None,
         site_filters="dt_20200416",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access SNP genotypes and associated data.
 
@@ -548,7 +550,7 @@ def open_genome(self):
             self._cache_genome = zarr.open_consolidated(store=store)
         return self._cache_genome
 
-    def genome_sequence(self, contig, inline_array=True, chunks="auto"):
+    def genome_sequence(self, contig, inline_array=True, chunks="native"):
         """Access the reference genome sequence.
 
         Parameters
         ----------
@@ -892,7 +894,7 @@ def site_annotations(
         site_mask=None,
         site_filters="dt_20200416",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Load site annotations.
 
@@ -1006,7 +1008,7 @@ def snp_calls(
         site_mask=None,
         site_filters="dt_20200416",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access SNP sites, site filters and genotype calls.
 
@@ -1175,7 +1177,7 @@ def cnv_hmm(
         contig,
         sample_sets="v3_wild",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access CNV HMM data.
 
@@ -1270,7 +1272,7 @@ def cnv_coverage_calls(
         sample_set,
         analysis,
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access CNV coverage calls data.
 
@@ -1471,7 +1473,7 @@ def cnv_discordant_read_calls(
         contig,
         sample_sets="v3_wild",
         inline_array=True,
-        chunks="auto",
+        chunks="native",
     ):
         """Access CNV discordant read calls data.
 
@@ -1672,6 +1674,195 @@ def gene_cnv_frequencies(self, contig, cohorts=None, sample_sets="v3_wild"):
 
         return df
 
+    def open_haplotypes(self, sample_set, analysis):
+        """Open haplotypes zarr.
+
+        Parameters
+        ----------
+        sample_set : str
+            Sample set identifier, e.g., "AG1000G-AO".
+        analysis : {"arab", "gamb_colu", "gamb_colu_arab"}
+            Which phasing analysis to use.
+
+        Returns
+        -------
+        root : zarr.hierarchy.Group
+
+        """
+        try:
+            return self._cache_haplotypes[(sample_set, analysis)]
+        except KeyError:
+            release = self._lookup_release(sample_set=sample_set)
+            path = f"{self._path}/{release}/snp_haplotypes/{sample_set}/{analysis}/zarr"
+            store = SafeStore(FSMap(root=path, fs=self._fs, check=False, create=False))
+            # some sample sets have no data for a given analysis, handle this
+            if ".zmetadata" not in store:
+                root = None
+            else:
+                root = zarr.open_consolidated(store=store)
+            self._cache_haplotypes[(sample_set, analysis)] = root
+            return root
+
+    def open_haplotype_sites(self, analysis):
+        """Open haplotype sites zarr.
+
+        Parameters
+        ----------
+        analysis : {"arab", "gamb_colu", "gamb_colu_arab"}
+            Which phasing analysis to use.
+
+        Returns
+        -------
+        root : zarr.hierarchy.Group
+
+        """
+        try:
+            return self._cache_haplotype_sites[analysis]
+        except KeyError:
+            path = f"{self._path}/v3/snp_haplotypes/sites/{analysis}/zarr"
+            store = SafeStore(FSMap(root=path, fs=self._fs, check=False, create=False))
+            root = zarr.open_consolidated(store=store)
+            self._cache_haplotype_sites[analysis] = root
+            return root
+
+    def _haplotypes_dataset(self, contig, sample_set, analysis, inline_array, chunks):
+
+        # open zarr
+        root = self.open_haplotypes(sample_set=sample_set, analysis=analysis)
+        sites = self.open_haplotype_sites(analysis=analysis)
+
+        # some sample sets have no data for a given analysis, handle this
+        if root is None:
+            return None
+
+        coords = dict()
+        data_vars = dict()
+
+        # variant_position
+        pos = sites[f"{contig}/variants/POS"]
+        coords["variant_position"] = (
+            [DIM_VARIANT],
+            from_zarr(pos, inline_array=inline_array, chunks=chunks),
+        )
+
+        # variant_contig
+        contig_index = self.contigs.index(contig)
+        coords["variant_contig"] = (
+            [DIM_VARIANT],
+            da.full_like(pos, fill_value=contig_index, dtype="u1"),
+        )
+
+        # variant_allele
+        ref = from_zarr(
+            sites[f"{contig}/variants/REF"], inline_array=inline_array, chunks=chunks
+        )
+        alt = from_zarr(
+            sites[f"{contig}/variants/ALT"], inline_array=inline_array, chunks=chunks
+        )
+        variant_allele = da.hstack([ref[:, None], alt[:, None]])
+        data_vars["variant_allele"] = [DIM_VARIANT, DIM_ALLELE], variant_allele
+
+        # call_genotype
+        data_vars["call_genotype"] = (
+            [DIM_VARIANT, DIM_SAMPLE, DIM_PLOIDY],
+            from_zarr(
+                root[f"{contig}/calldata/GT"], inline_array=inline_array, chunks=chunks
+            ),
+        )
+
+        # sample arrays
+        coords["sample_id"] = (
+            [DIM_SAMPLE],
+            from_zarr(root["samples"], inline_array=inline_array, chunks=chunks),
+        )
+
+        # setup attributes
+        attrs = {"contigs": self.contigs}
+
+        # create a dataset
+        ds = xarray.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
+
+        return ds
+
+    def haplotypes(
+        self,
+        contig,
+        analysis,
+        sample_sets="v3_wild",
+        inline_array=True,
+        chunks="native",
+    ):
+        """Access haplotype data.
+
+        Parameters
+        ----------
+        contig : str
+            Chromosome arm, e.g., "3R".
+        analysis : {"arab", "gamb_colu", "gamb_colu_arab"}
+            Which phasing analysis to use. If analysing only An. arabiensis, the
+            "arab" analysis is best. If analysing only An. gambiae and An.
+            coluzzii, the "gamb_colu" analysis is best. Otherwise use the
+            "gamb_colu_arab" analysis.
+        sample_sets : str or list of str
+            Can be a sample set identifier (e.g., "AG1000G-AO") or a list of
+            sample set identifiers (e.g., ["AG1000G-BF-A", "AG1000G-BF-B"]) or a
+            release identifier (e.g., "v3") or a list of release identifiers.
+        inline_array : bool, optional
+            Passed through to dask.array.from_array().
+        chunks : str, optional
+            If 'auto', let dask decide chunk size. If 'native', use native zarr
+            chunks. Also can be a target size, e.g., '200 MiB'.
+
+        Returns
+        -------
+        ds : xarray.Dataset
+
+        """
+
+        sample_sets = self._prep_sample_sets_arg(sample_sets=sample_sets)
+
+        if isinstance(sample_sets, str):
+
+            # single sample set requested
+            ds = self._haplotypes_dataset(
+                contig=contig,
+                sample_set=sample_sets,
+                analysis=analysis,
+                inline_array=inline_array,
+                chunks=chunks,
+            )
+
+        else:
+
+            # multiple sample sets requested, need to concatenate along samples dimension
+            datasets = [
+                self._haplotypes_dataset(
+                    contig=contig,
+                    sample_set=sample_set,
+                    analysis=analysis,
+                    inline_array=inline_array,
+                    chunks=chunks,
+                )
+                for sample_set in sample_sets
+            ]
+            # some sample sets have no data for a given analysis, handle this
+            datasets = [d for d in datasets if d is not None]
+            if len(datasets) == 0:
+                ds = None
+            else:
+                ds = xarray.concat(
+                    datasets,
+                    dim=DIM_SAMPLE,
+                    data_vars="minimal",
+                    coords="minimal",
+                    compat="override",
+                    join="override",
+                )
+
+        # if no samples at all, raise
+        if ds is None:
+            raise ValueError(
+                f"no samples available for analysis {analysis!r} and sample sets {sample_sets!r}"
+            )
+
+        return ds
+
 
 @numba.njit("Tuple((int8, int64))(int8[:], int8)")
 def _cn_mode_1d(a, vmax):
diff --git a/pyproject.toml b/pyproject.toml
index 309c75a3a..9082f3075 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "malariagen_data"
-version = "0.8.0"
+version = "0.9.0"
 description = "A package for accessing MalariaGEN public data."
 authors = ["Alistair Miles ", "Chris Clarkson "]
 license = "MIT"
diff --git a/tests/test_ag3.py b/tests/test_ag3.py
index 101d331d5..07f03df14 100644
--- a/tests/test_ag3.py
+++ b/tests/test_ag3.py
@@ -1063,3 +1063,90 @@ def test_gene_cnv_frequencies_0_cohort(contig):
         _ = ag3.gene_cnv_frequencies(
             contig=contig, sample_sets="v3_wild", cohorts=cohorts
         )
+
+
+@pytest.mark.parametrize(
+    "sample_sets", ["AG1000G-BF-A", ("AG1000G-TZ", "AG1000G-UG"), "v3", "v3_wild"]
+)
+@pytest.mark.parametrize("contig", ["2R", "X"])
+@pytest.mark.parametrize("analysis", ["arab", "gamb_colu", "gamb_colu_arab"])
+def test_haplotypes(sample_sets, contig, analysis):
+
+    ag3 = setup_ag3()
+
+    # check expected samples
+    sample_query = None
+    if analysis == "arab":
+        sample_query = "species == 'arabiensis' and sample_set != 'AG1000G-X'"
+    elif analysis == "gamb_colu":
+        sample_query = "species in ['gambiae', 'coluzzii', 'intermediate_gambiae_coluzzii'] and sample_set != 'AG1000G-X'"
+    elif analysis == "gamb_colu_arab":
+        sample_query = "sample_set != 'AG1000G-X'"
+    df_samples = ag3.sample_metadata(sample_sets=sample_sets)
+    expected_samples = set(df_samples.query(sample_query)["sample_id"].tolist())
+    n_samples = len(expected_samples)
+
+    # check if any samples
+    if n_samples == 0:
+        with pytest.raises(ValueError):
+            # no samples, raise
+            ag3.haplotypes(contig=contig, sample_sets=sample_sets, analysis=analysis)
+        return
+
+    ds = ag3.haplotypes(contig=contig, sample_sets=sample_sets, analysis=analysis)
+    assert isinstance(ds, xarray.Dataset)
+
+    # check fields
+    expected_data_vars = {
+        "variant_allele",
+        "call_genotype",
+    }
+    assert set(ds.data_vars) == expected_data_vars
+
+    expected_coords = {
+        "variant_contig",
+        "variant_position",
+        "sample_id",
+    }
+    assert set(ds.coords) == expected_coords
+
+    # check dimensions
+    assert set(ds.dims) == {"alleles", "ploidy", "samples", "variants"}
"ploidy", "samples", "variants"} + + # check samples + samples = set(ds["sample_id"].values) + assert samples == expected_samples + + # check dim lengths + assert ds.dims["samples"] == n_samples + assert ds.dims["ploidy"] == 2 + assert ds.dims["alleles"] == 2 + + # check shapes + for f in expected_coords | expected_data_vars: + x = ds[f] + assert isinstance(x, xarray.DataArray) + assert isinstance(x.data, da.Array) + + if f == "variant_allele": + assert x.ndim, f == 2 + assert x.shape[1] == 2 + assert x.dims == ("variants", "alleles") + elif f.startswith("variant_"): + assert x.ndim, f == 1 + assert x.dims == ("variants",) + elif f == "call_genotype": + assert x.ndim == 3 + assert x.dims == ("variants", "samples", "ploidy") + assert x.shape[1] == n_samples + assert x.shape[2] == 2 + + # check attributes + assert "contigs" in ds.attrs + assert ds.attrs["contigs"] == ("2R", "2L", "3R", "3L", "X") + + # check can setup computations + d1 = ds["variant_position"] > 10_000 + assert isinstance(d1, xarray.DataArray) + d2 = ds["call_genotype"].sum(axis=(1, 2)) + assert isinstance(d2, xarray.DataArray)