Skip to content

Commit

Permalink
Further work on scalability for large biallelic genotype data computations (#626)
Browse files Browse the repository at this point in the history

* back out usage of zarr

* revert to native chunks for now

* revert to native chunks for now
  • Loading branch information
alimanfoo authored Sep 24, 2024
1 parent 9f13fb0 commit b4a3cc9
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 61 deletions.
2 changes: 1 addition & 1 deletion malariagen_data/ag3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .anopheles import AnophelesDataResource

# silence dask performance warnings
dask.config.set(**{"array.slicing.split_large_chunks": False}) # type: ignore
# NOTE: the dask option is "array.slicing.split_large_chunks". It is unrelated
# to this package's own `native_chunks` default, so it must keep its name;
# dask accepts unknown keys without error, which would leave the warning on.
dask.config.set(**{"array.slicing.split_large_chunks": False}) # type: ignore

MAJOR_VERSION_NUMBER = 3
MAJOR_VERSION_PATH = "v3"
Expand Down
4 changes: 0 additions & 4 deletions malariagen_data/anoph/base_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,6 @@ def validate_sample_selection_params(
# amounts of data.
native_chunks: chunks = "native"

# Alternative default chunk size, suitable for functions which need to
# scan a large amount of data.
large_chunks: chunks = "300MiB"

gff_attributes: TypeAlias = Annotated[
Optional[Union[Sequence[str], str]],
"""
Expand Down
2 changes: 1 addition & 1 deletion malariagen_data/anoph/fst.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def fst_gwss(
] = fst_params.max_cohort_size_default,
random_seed: base_params.random_seed = 42,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
clip_min: fst_params.clip_min = 0.0,
) -> Tuple[np.ndarray, np.ndarray]:
# Change this name if you ever change the behaviour of this function, to
Expand Down
10 changes: 5 additions & 5 deletions malariagen_data/anoph/g123.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def g123_gwss(
] = g123_params.max_cohort_size_default,
random_seed: base_params.random_seed = 42,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> Tuple[np.ndarray, np.ndarray]:
# Change this name if you ever change the behaviour of this function, to
# invalidate any previously cached data.
Expand Down Expand Up @@ -264,7 +264,7 @@ def g123_calibration(
window_sizes: g123_params.window_sizes = g123_params.window_sizes_default,
random_seed: base_params.random_seed = 42,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> Mapping[str, np.ndarray]:
# Change this name if you ever change the behaviour of this function, to
# invalidate any previously cached data.
Expand Down Expand Up @@ -323,7 +323,7 @@ def plot_g123_gwss_track(
x_range: Optional[gplt_params.x_range] = None,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> gplt_params.figure:
# compute G123
x, g123 = self.g123_gwss(
Expand Down Expand Up @@ -424,7 +424,7 @@ def plot_g123_gwss(
show: gplt_params.show = True,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> gplt_params.figure:
# gwss track
fig1 = self.plot_g123_gwss_track(
Expand Down Expand Up @@ -497,7 +497,7 @@ def plot_g123_calibration(
title: Optional[gplt_params.title] = None,
show: gplt_params.show = True,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> gplt_params.figure:
# get g123 values
calibration_runs = self.g123_calibration(
Expand Down
10 changes: 5 additions & 5 deletions malariagen_data/anoph/h12.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def h12_calibration(
] = h12_params.max_cohort_size_default,
window_sizes: h12_params.window_sizes = h12_params.window_sizes_default,
random_seed: base_params.random_seed = 42,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> Mapping[str, np.ndarray]:
# Change this name if you ever change the behaviour of this function, to
Expand Down Expand Up @@ -143,7 +143,7 @@ def plot_h12_calibration(
random_seed: base_params.random_seed = 42,
title: Optional[str] = None,
show: bool = True,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# Get H12 values.
Expand Down Expand Up @@ -286,7 +286,7 @@ def h12_gwss(
base_params.max_cohort_size
] = h12_params.max_cohort_size_default,
random_seed: base_params.random_seed = 42,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> Tuple[np.ndarray, np.ndarray]:
# Change this name if you ever change the behaviour of this function, to
Expand Down Expand Up @@ -346,7 +346,7 @@ def plot_h12_gwss_track(
show: gplt_params.show = True,
x_range: Optional[gplt_params.x_range] = None,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# Compute H12.
Expand Down Expand Up @@ -447,7 +447,7 @@ def plot_h12_gwss(
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
show: gplt_params.show = True,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# Plot GWSS track.
Expand Down
6 changes: 3 additions & 3 deletions malariagen_data/anoph/h1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def h1x_gwss(
base_params.max_cohort_size
] = h12_params.max_cohort_size_default,
random_seed: base_params.random_seed = 42,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> Tuple[np.ndarray, np.ndarray]:
# Change this name if you ever change the behaviour of this function, to
Expand Down Expand Up @@ -177,7 +177,7 @@ def plot_h1x_gwss_track(
show: gplt_params.show = True,
x_range: Optional[gplt_params.x_range] = None,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# Compute H1X.
Expand Down Expand Up @@ -283,7 +283,7 @@ def plot_h1x_gwss(
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
show: gplt_params.show = True,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# Plot GWSS track.
Expand Down
2 changes: 1 addition & 1 deletion malariagen_data/anoph/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def pca(
fit_exclude_samples: Optional[base_params.samples] = None,
random_seed: base_params.random_seed = 42,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> Tuple[pca_params.df_pca, pca_params.evr]:
# Change this name if you ever change the behaviour of this function, to
# invalidate any previously cached data.
Expand Down
22 changes: 6 additions & 16 deletions malariagen_data/anoph/snp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DIM_VARIANT,
CacheMiss,
Region,
apply_allele_mapping,
check_types,
da_compress,
da_concat,
Expand Down Expand Up @@ -1629,30 +1630,20 @@ def biallelic_snp_calls(
variant_position = ds_bi["variant_position"].data
coords["variant_position"] = ("variants",), variant_position

# Prepare allele mapping for dask computations.
allele_mapping_zarr = zarr.array(allele_mapping)
allele_mapping_dask = da_from_zarr(
allele_mapping_zarr, chunks="native", inline_array=True
)

# Store alleles, transformed.
variant_allele_dask = ds_bi["variant_allele"].data
variant_allele_out = dask_apply_allele_mapping(
variant_allele_dask, allele_mapping_dask, max_allele=1
variant_allele_dask, allele_mapping, max_allele=1
)
data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out

# Store allele counts, transformed.
ac_bi_zarr = zarr.array(ac_bi)
ac_bi_dask = da_from_zarr(ac_bi_zarr, chunks="native", inline_array=True)
ac_out = dask_apply_allele_mapping(
ac_bi_dask, allele_mapping_dask, max_allele=1
)
ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1)
data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out

# Store genotype calls, transformed.
gt_dask = ds_bi["call_genotype"].data
gt_out = dask_genotype_array_map_alleles(gt_dask, allele_mapping_dask)
gt_out = dask_genotype_array_map_alleles(gt_dask, allele_mapping)
data_vars["call_genotype"] = (
(
"variants",
Expand All @@ -1667,9 +1658,8 @@ def biallelic_snp_calls(

# Apply conditions.
if max_missing_an is not None or min_minor_ac is not None:
ac_out_computed = ac_out.compute()
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)
an = ac_out_computed.sum(axis=1)
an = ac_out.sum(axis=1)

# Apply missingness condition.
if max_missing_an is not None:
Expand All @@ -1683,7 +1673,7 @@ def biallelic_snp_calls(

# Apply minor allele count condition.
if min_minor_ac is not None:
ac_minor = ac_out_computed.min(axis=1)
ac_minor = ac_out.min(axis=1)
if isinstance(min_minor_ac, float):
ac_minor_frac = ac_minor / an
loc_minor = ac_minor_frac >= min_minor_ac
Expand Down
18 changes: 9 additions & 9 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ def cohort_diversity_stats(
random_seed: base_params.random_seed = 42,
n_jack: base_params.n_jack = 200,
confidence_level: base_params.confidence_level = 0.95,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> pd.Series:
debug = self._log.debug
Expand Down Expand Up @@ -1728,7 +1728,7 @@ def diversity_stats(
random_seed: base_params.random_seed = 42,
n_jack: base_params.n_jack = 200,
confidence_level: base_params.confidence_level = 0.95,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> pd.DataFrame:
# Normalise cohorts parameter.
Expand Down Expand Up @@ -1933,7 +1933,7 @@ def ihs_gwss(
base_params.max_cohort_size
] = ihs_params.max_cohort_size_default,
random_seed: base_params.random_seed = 42,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> Tuple[np.ndarray, np.ndarray]:
# change this name if you ever change the behaviour of this function, to
Expand Down Expand Up @@ -2110,7 +2110,7 @@ def plot_ihs_gwss_track(
show: gplt_params.show = True,
x_range: Optional[gplt_params.x_range] = None,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# compute ihs
Expand Down Expand Up @@ -2251,7 +2251,7 @@ def plot_xpehh_gwss(
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
show: gplt_params.show = True,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# gwss track
Expand Down Expand Up @@ -2350,7 +2350,7 @@ def plot_ihs_gwss(
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
show: gplt_params.show = True,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# gwss track
Expand Down Expand Up @@ -2445,7 +2445,7 @@ def xpehh_gwss(
base_params.max_cohort_size
] = xpehh_params.max_cohort_size_default,
random_seed: base_params.random_seed = 42,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> Tuple[np.ndarray, np.ndarray]:
# change this name if you ever change the behaviour of this function, to
Expand Down Expand Up @@ -2624,7 +2624,7 @@ def plot_xpehh_gwss_track(
show: gplt_params.show = True,
x_range: Optional[gplt_params.x_range] = None,
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
) -> gplt_params.figure:
# compute xpehh
Expand Down Expand Up @@ -3269,7 +3269,7 @@ def plot_njt(
max_cohort_size: Optional[base_params.max_cohort_size] = None,
random_seed: base_params.random_seed = 42,
inline_array: base_params.inline_array = base_params.inline_array_default,
chunks: base_params.chunks = base_params.large_chunks,
chunks: base_params.chunks = base_params.native_chunks,
) -> plotly_params.figure:
from biotite.sequence.phylo import neighbor_joining # type: ignore
from scipy.spatial.distance import squareform # type: ignore
Expand Down
34 changes: 21 additions & 13 deletions malariagen_data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ def da_from_zarr(
#
# N.B., only resize chunks in arrays with more than one dimension,
# because resizing the one-dimensional arrays according to the same
# size generally leads to poor performance with our datasets.
# size may lead to poor performance with our datasets.
#
# Also, resize along the first dimension only. Again, this is something
# that generally works well for our datasets.
# that may work well for our datasets.
#
# Note that dask also supports this kind of argument, and so we could
# just pass this through. However, some experiments have found this
Expand Down Expand Up @@ -313,8 +313,6 @@ def dask_compress_dataset(ds, indexer, dim):
else:
assert isinstance(indexer, np.ndarray)
indexer_computed = indexer
indexer_zarr = zarr.array(indexer_computed)
indexer = da_from_zarr(indexer_zarr, chunks="native", inline_array=True)

coords = dict()
for k in ds.coords:
Expand Down Expand Up @@ -344,15 +342,22 @@ def _dask_compress_dataarray(a, indexer, indexer_computed, dim):

else:
# apply the indexing operation
v = da_compress(
indexer=indexer, data=a.data, axis=axis, indexer_computed=indexer_computed
)
data = a.data
if isinstance(data, da.Array):
v = da_compress(
indexer=indexer,
data=a.data,
axis=axis,
indexer_computed=indexer_computed,
)
else:
v = np.compress(indexer_computed, data, axis=axis)

return v


def da_compress(
indexer: da.Array,
indexer: da.Array | np.ndarray,
data: da.Array,
axis: int,
indexer_computed: Optional[np.ndarray] = None,
Expand All @@ -373,7 +378,10 @@ def da_compress(
indexer_computed = indexer.compute()

# Ensure indexer and data are chunked in the same way.
indexer = indexer.rechunk((axis_old_chunks,))
if isinstance(indexer, da.Array):
indexer = indexer.rechunk((axis_old_chunks,))
else:
indexer = da.from_array(indexer, chunks=(axis_old_chunks,))

# Apply the indexing operation.
v = da.compress(indexer, data, axis=axis)
Expand Down Expand Up @@ -1490,12 +1498,12 @@ def apply_allele_mapping(x, mapping, max_allele):

def dask_apply_allele_mapping(v, mapping, max_allele):
assert isinstance(v, da.Array)
assert isinstance(mapping, da.Array)
assert isinstance(mapping, np.ndarray)
assert v.ndim == 2
assert mapping.ndim == 2
assert v.shape[0] == mapping.shape[0]
v = v.rechunk((v.chunks[0], -1))
mapping = mapping.rechunk((v.chunks[0], -1))
mapping = da.from_array(mapping, chunks=(v.chunks[0], -1))
out = da.map_blocks(
lambda xb, mb: apply_allele_mapping(xb, mb, max_allele=max_allele),
v,
Expand Down Expand Up @@ -1531,11 +1539,11 @@ def genotype_array_map_alleles(gt, mapping):

def dask_genotype_array_map_alleles(gt, mapping):
assert isinstance(gt, da.Array)
assert isinstance(mapping, da.Array)
assert isinstance(mapping, np.ndarray)
assert gt.ndim == 3
assert mapping.ndim == 2
assert gt.shape[0] == mapping.shape[0]
mapping = mapping.rechunk((gt.chunks[0], -1))
mapping = da.from_array(mapping, chunks=(gt.chunks[0], -1))
gt_out = da.map_blocks(
genotype_array_map_alleles,
gt,
Expand Down
Loading

0 comments on commit b4a3cc9

Please sign in to comment.