Improve performance when accessing biallelic SNP calls (#623)
* fix bug with biallelic snp calls and variant_allele

* optimisations, wip

* massive thrash

* comment

* refactor

* deal with strange performance issue in zarr and fsspec

* poetry update

* tweaks

* revert silly mistake

* avoid getattr pickle black hole of doom

* fix gcs bucket

* remove outputs

* tweak test

* tweak test
alimanfoo authored Sep 23, 2024
1 parent 6c3554b commit 9f13fb0
Showing 9 changed files with 672 additions and 3,650 deletions.
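For context, a minimal usage sketch of the method this commit optimizes. This is an illustration only: the Ag3 entry point, the sample_sets argument, and the region and sample set values are assumptions, while max_missing_an, min_minor_ac and n_snps are parameters visible in the diff below.

import malariagen_data

# Hypothetical example; region and sample set are placeholders.
ag3 = malariagen_data.Ag3()
ds = ag3.biallelic_snp_calls(
    region="3L",
    sample_sets="AG1000G-BF-A",
    max_missing_an=10,
    min_minor_ac=1,
    n_snps=100_000,
)
print(ds)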
49 changes: 30 additions & 19 deletions malariagen_data/anoph/snp_data.py
@@ -17,12 +17,13 @@
DIM_VARIANT,
CacheMiss,
Region,
apply_allele_mapping,
check_types,
da_compress,
da_concat,
da_from_zarr,
dask_apply_allele_mapping,
dask_compress_dataset,
dask_genotype_array_map_alleles,
init_zarr_store,
locate_region,
parse_multi_region,
@@ -565,6 +566,7 @@ def _snp_variants_for_contig(
ref = da_from_zarr(ref_z, inline_array=inline_array, chunks=chunks)
alt = da_from_zarr(alt_z, inline_array=inline_array, chunks=chunks)
variant_allele = da.concatenate([ref[:, None], alt], axis=1)
variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
data_vars["variant_allele"] = [DIM_VARIANT, DIM_ALLELE], variant_allele

# Set up variant_contig.
@@ -1611,7 +1613,7 @@ def biallelic_snp_calls(

with self._spinner("Prepare biallelic SNP calls"):
# Subset to biallelic sites.
ds_bi = ds.isel(variants=loc_bi)
ds_bi = dask_compress_dataset(ds, indexer=loc_bi, dim="variants")

# Start building a new dataset.
coords: Dict[str, Any] = dict()
@@ -1624,42 +1626,50 @@
coords["variant_contig"] = ("variants",), ds_bi["variant_contig"].data

# Store position.
coords["variant_position"] = ("variants",), ds_bi["variant_position"].data
variant_position = ds_bi["variant_position"].data
coords["variant_position"] = ("variants",), variant_position

# Prepare allele mapping for dask computations.
allele_mapping_zarr = zarr.array(allele_mapping)
allele_mapping_dask = da_from_zarr(
allele_mapping_zarr, chunks="native", inline_array=True
)

# Store alleles, transformed.
variant_allele = ds_bi["variant_allele"].data
variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
variant_allele_out = da.map_blocks(
lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1),
variant_allele,
dtype=variant_allele.dtype,
chunks=(variant_allele.chunks[0], [2]),
variant_allele_dask = ds_bi["variant_allele"].data
variant_allele_out = dask_apply_allele_mapping(
variant_allele_dask, allele_mapping_dask, max_allele=1
)
data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out

# Store allele counts, transformed, so we don't have to recompute.
ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1)
# Store allele counts, transformed.
ac_bi_zarr = zarr.array(ac_bi)
ac_bi_dask = da_from_zarr(ac_bi_zarr, chunks="native", inline_array=True)
ac_out = dask_apply_allele_mapping(
ac_bi_dask, allele_mapping_dask, max_allele=1
)
data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out

# Store genotype calls, transformed.
gt = ds_bi["call_genotype"].data
gt_out = allel.GenotypeDaskArray(gt).map_alleles(allele_mapping)
gt_dask = ds_bi["call_genotype"].data
gt_out = dask_genotype_array_map_alleles(gt_dask, allele_mapping_dask)
data_vars["call_genotype"] = (
(
"variants",
"samples",
"ploidy",
),
gt_out.values,
gt_out,
)

# Build dataset.
ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs)

# Apply conditions.
if max_missing_an is not None or min_minor_ac is not None:
ac_out_computed = ac_out.compute()
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)
an = ac_out.sum(axis=1)
an = ac_out_computed.sum(axis=1)

# Apply missingness condition.
if max_missing_an is not None:
Expand All @@ -1673,20 +1683,21 @@ def biallelic_snp_calls(

# Apply minor allele count condition.
if min_minor_ac is not None:
ac_minor = ac_out.min(axis=1)
ac_minor = ac_out_computed.min(axis=1)
if isinstance(min_minor_ac, float):
ac_minor_frac = ac_minor / an
loc_minor = ac_minor_frac >= min_minor_ac
else:
loc_minor = ac_minor >= min_minor_ac
loc_out &= loc_minor

ds_out = ds_out.isel(variants=loc_out)
# Apply selection from conditions.
ds_out = dask_compress_dataset(ds_out, indexer=loc_out, dim="variants")

# Try to meet target number of SNPs.
if n_snps is not None:
if ds_out.sizes["variants"] > (n_snps * 2):
# Do some thinning.
# Apply thinning.
thin_step = ds_out.sizes["variants"] // n_snps
loc_thin = slice(thin_offset, None, thin_step)
ds_out = ds_out.isel(variants=loc_thin)
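A note on the pattern introduced above for allele_mapping and ac_bi: small in-memory NumPy arrays are first wrapped in a zarr array and then re-exposed to dask with native chunks via da_from_zarr, so that dask_apply_allele_mapping receives them as dask arrays rather than as NumPy arrays captured in a lambda closure (as in the code being replaced). A standalone sketch of the same idea with assumed toy data, using plain da.from_array in place of the package's da_from_zarr helper:

import numpy as np
import zarr
import dask.array as da

# Toy allele mapping (assumed data): mapping[i, old_allele] = new_allele; -1 drops the allele.
allele_mapping = np.array([[0, -1, 1, -1], [-1, 0, -1, 1]], dtype="i1")

# Wrap in a zarr array, then expose to dask with chunks matching the zarr chunks,
# analogous to da_from_zarr(..., chunks="native", inline_array=True) above.
allele_mapping_zarr = zarr.array(allele_mapping)
allele_mapping_dask = da.from_array(
    allele_mapping_zarr, chunks=allele_mapping_zarr.chunks, inline_array=True
)
print(allele_mapping_dask.chunks)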
179 changes: 132 additions & 47 deletions malariagen_data/util.py
@@ -29,6 +29,9 @@
import typeguard
import xarray as xr
import zarr # type: ignore

# zarr >= 2.11.0
from zarr.storage import BaseStore # type: ignore
from fsspec.core import url_to_fs # type: ignore
from fsspec.mapping import FSMap # type: ignore
from numpydoc_decorator.impl import humanize_type # type: ignore
@@ -113,46 +116,40 @@ def unpack_gff3_attributes(df: pd.DataFrame, attributes: Tuple[str, ...]):
return df


# zarr compatibility, version 2.11.0 introduced the BaseStore class
# see also https://github.com/malariagen/malariagen-data-python/issues/129

try:
# zarr >= 2.11.0
from zarr.storage import KVStore # type: ignore

class SafeStore(KVStore):
def __getitem__(self, key):
try:
return self._mutable_mapping[key]
except KeyError as e:
# raise a different error to ensure zarr propagates the exception, rather than filling
raise FileNotFoundError(e)
class SafeStore(BaseStore):
"""This class wraps any zarr store and ensures that missing chunks
will not get automatically filled but will raise an exception. There
should be no missing chunks in any of the datasets we host."""

def __contains__(self, key):
return key in self._mutable_mapping
def __init__(self, store):
self._store = store

except ImportError:
# zarr < 2.11.0
def __getitem__(self, key):
try:
return self._store[key]
except KeyError as e:
# Raise a different error to ensure zarr propagates the exception,
# rather than filling.
raise FileNotFoundError(e)

class SafeStore(Mapping): # type: ignore
def __init__(self, store):
self.store = store
def __getattr__(self, attr):
if attr == "__setstate__":
# Special method called during unpickling, don't pass through.
raise AttributeError(attr)
# Pass through all other attribute access to the wrapped store.
return getattr(self._store, attr)

def __getitem__(self, key):
try:
return self.store[key]
except KeyError as e:
# raise a different error to ensure zarr propagates the exception, rather than filling
raise FileNotFoundError(e)
def __iter__(self):
return iter(self._store)

def __contains__(self, key):
return key in self.store
def __len__(self):
return len(self._store)

def __iter__(self):
return iter(self.store)
def __setitem__(self, item):
raise NotImplementedError

def __len__(self):
return len(self.store)
def __delitem__(self, item):
raise NotImplementedError


class SiteClass(Enum):
@@ -269,7 +266,11 @@ def da_from_zarr(
dask_chunks = chunks

kwargs = dict(
chunks=dask_chunks, fancy=False, lock=False, inline_array=inline_array
inline_array=inline_array,
chunks=dask_chunks,
fancy=True,
lock=False,
asarray=True,
)
try:
d = da.from_array(z, **kwargs)
@@ -301,14 +302,19 @@ def dask_compress_dataset(ds, indexer, dim):
indexer = ds[indexer].data

# sanity checks
assert isinstance(indexer, da.Array)
assert indexer.ndim == 1
assert indexer.dtype == bool
assert indexer.shape[0] == ds.sizes[dim]

# temporarily compute the indexer once, to avoid multiple reads from
# the underlying data
indexer_computed = indexer.compute()
if isinstance(indexer, da.Array):
# temporarily compute the indexer once, to avoid multiple reads from
# the underlying data
indexer_computed = indexer.compute()
else:
assert isinstance(indexer, np.ndarray)
indexer_computed = indexer
indexer_zarr = zarr.array(indexer_computed)
indexer = da_from_zarr(indexer_zarr, chunks="native", inline_array=True)

coords = dict()
for k in ds.coords:
@@ -353,32 +359,36 @@ def da_compress(
):
"""Wrapper for dask.array.compress() which computes chunk sizes faster."""

# sanity checks
# Sanity checks.
assert indexer.ndim == 1
assert indexer.dtype == bool
assert indexer.shape[0] == data.shape[axis]

# useful variables
# Useful variables.
old_chunks = data.chunks
axis_old_chunks = old_chunks[axis]

# load the indexer temporarily for chunk size computations
# Load the indexer temporarily for chunk size computations.
if indexer_computed is None:
indexer_computed = indexer.compute()

# ensure indexer and data are chunked in the same way
# Ensure indexer and data are chunked in the same way.
indexer = indexer.rechunk((axis_old_chunks,))

# apply the indexing operation
# Apply the indexing operation.
v = da.compress(indexer, data, axis=axis)

# need to compute chunks sizes in order to know dimension sizes;
# Need to compute chunks sizes in order to know dimension sizes;
# would normally do v.compute_chunk_sizes() but that is slow for
# multidimensional arrays, so hack something more efficient

# multidimensional arrays, so hack something more efficient.
axis_new_chunks_list = []
slice_start = 0
need_rechunk = False
for old_chunk_size in axis_old_chunks:
slice_stop = slice_start + old_chunk_size
new_chunk_size = np.sum(indexer_computed[slice_start:slice_stop])
new_chunk_size = int(np.sum(indexer_computed[slice_start:slice_stop]))
if new_chunk_size == 0:
need_rechunk = True
axis_new_chunks_list.append(new_chunk_size)
slice_start = slice_stop
axis_new_chunks = tuple(axis_new_chunks_list)
@@ -387,6 +397,23 @@
)
v._chunks = new_chunks

# Deal with empty chunks, they break reductions.
# Possibly related to https://github.com/dask/dask/issues/10327
# and https://github.com/dask/dask/issues/2794
if need_rechunk:
axis_new_chunks_nonzero = tuple([x for x in axis_new_chunks if x > 0])
# Edge case, all chunks empty:
if len(axis_new_chunks_nonzero) == 0:
# Not much we can do about this, no data.
axis_new_chunks_nonzero = (0,)
new_chunks_nonzero = tuple(
[
axis_new_chunks_nonzero if i == axis else c
for i, c in enumerate(new_chunks)
]
)
v = v.rechunk(new_chunks_nonzero)

return v


@@ -1461,6 +1488,64 @@ def apply_allele_mapping(x, mapping, max_allele):
return out


def dask_apply_allele_mapping(v, mapping, max_allele):
assert isinstance(v, da.Array)
assert isinstance(mapping, da.Array)
assert v.ndim == 2
assert mapping.ndim == 2
assert v.shape[0] == mapping.shape[0]
v = v.rechunk((v.chunks[0], -1))
mapping = mapping.rechunk((v.chunks[0], -1))
out = da.map_blocks(
lambda xb, mb: apply_allele_mapping(xb, mb, max_allele=max_allele),
v,
mapping,
dtype=v.dtype,
chunks=(v.chunks[0], [max_allele + 1]),
)
return out


def genotype_array_map_alleles(gt, mapping):
# Transform genotype calls via an allele mapping.
# N.B., scikit-allel does not handle empty blocks well, so we
# include some extra logic to handle that better.
assert isinstance(gt, np.ndarray)
assert isinstance(mapping, np.ndarray)
assert gt.ndim == 3
assert mapping.ndim == 3
assert gt.shape[0] == mapping.shape[0]
assert gt.shape[1] > 0
assert gt.shape[2] == 2
if gt.size > 0:
# Block is not empty, can pass through to GenotypeArray.
assert gt.shape[0] > 0
m = mapping[:, 0, :]
out = allel.GenotypeArray(gt).map_alleles(m).values
else:
# Block is empty so no alleles need to be mapped.
assert gt.shape[0] == 0
out = gt
return out


def dask_genotype_array_map_alleles(gt, mapping):
assert isinstance(gt, da.Array)
assert isinstance(mapping, da.Array)
assert gt.ndim == 3
assert mapping.ndim == 2
assert gt.shape[0] == mapping.shape[0]
mapping = mapping.rechunk((gt.chunks[0], -1))
gt_out = da.map_blocks(
genotype_array_map_alleles,
gt,
mapping[:, None, :],
chunks=gt.chunks,
dtype=gt.dtype,
)
return gt_out


def pandas_apply(f, df, columns):
"""Optimised alternative to pandas apply."""
df = df.reset_index(drop=True)
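For readers unfamiliar with the allele-mapping transform that dask_apply_allele_mapping and dask_genotype_array_map_alleles wrap, here is a small NumPy-level sketch with made-up data, using the same scikit-allel map_alleles call that genotype_array_map_alleles delegates to:

import numpy as np
import allel

# Two variants, two samples, diploid calls (assumed data).
gt = np.array(
    [
        [[0, 2], [2, 2]],  # variant 0: alleles 0 and 2 observed
        [[1, 3], [3, 1]],  # variant 1: alleles 1 and 3 observed
    ],
    dtype="i1",
)

# mapping[i, old_allele] = new_allele; -1 marks alleles that are dropped.
mapping = np.array(
    [
        [0, -1, 1, -1],  # variant 0: 0 -> 0, 2 -> 1
        [-1, 0, -1, 1],  # variant 1: 1 -> 0, 3 -> 1
    ],
    dtype="i1",
)

gt_bi = allel.GenotypeArray(gt).map_alleles(mapping)
print(gt_bi.values)
# variant 0 -> [[0, 1], [1, 1]]; variant 1 -> [[0, 1], [1, 0]]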
