Misc bug fixes (#480)
* use Dataset.sizes[] instead of Dataset.dims[]

* fix param typing

* allow taxon order override

* fix param handling

* more params

* more progress

* relax typing
alimanfoo authored Dec 11, 2023
1 parent 8843cf5 commit b797476
Showing 18 changed files with 261 additions and 142 deletions.
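
The headline fix swaps dict-style access on Dataset.dims for Dataset.sizes throughout. For context, a minimal sketch, not from this repo: on recent xarray releases, Dataset.dims is slated to become a set of dimension names, so mapping-style access emits a deprecation warning, while Dataset.sizes is the stable name-to-length mapping.

    # Sketch, assuming recent xarray deprecation behavior (not repo code).
    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"call_genotype": (("variants", "samples"), np.zeros((10, 3)))})

    n_samples = ds.sizes["samples"]  # stable mapping access; ds.dims["samples"] warns
    assert n_samples == 3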
4 changes: 2 additions & 2 deletions malariagen_data/anoph/hap_data.py
@@ -270,15 +270,15 @@ def haplotypes(

if min_cohort_size is not None:
# Handle min cohort size.
n_samples = ds.dims["samples"]
n_samples = ds.sizes["samples"]
if n_samples < min_cohort_size:
raise ValueError(
f"Not enough samples ({n_samples}) for minimum cohort size ({min_cohort_size})"
)

if max_cohort_size is not None:
# Handle max cohort size.
n_samples = ds.dims["samples"]
n_samples = ds.sizes["samples"]
if n_samples > max_cohort_size:
rng = np.random.default_rng(seed=random_seed)
loc_downsample = rng.choice(
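
The hunk above is truncated inside the downsampling call, but the pattern is clear: when a cohort exceeds max_cohort_size, a seeded generator picks that many sample indices without replacement. A standalone sketch with illustrative values (the final sort is an assumption, to keep sample order stable after subsetting):

    import numpy as np

    n_samples, max_cohort_size, random_seed = 100, 20, 42
    rng = np.random.default_rng(seed=random_seed)
    loc_downsample = rng.choice(n_samples, size=max_cohort_size, replace=False)
    loc_downsample.sort()  # assumed: preserve original sample order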
15 changes: 12 additions & 3 deletions malariagen_data/anoph/plotly_params.py
@@ -62,7 +62,7 @@
]

category_order: TypeAlias = Annotated[
Optional[List],
Optional[Union[List, Mapping]],
"Control the order in which values appear in the legend.",
]

@@ -97,12 +97,12 @@
]

color: TypeAlias = Annotated[
Optional[str],
Optional[Union[str, Mapping]],
"Name of variable to use to color the markers.",
]

symbol: TypeAlias = Annotated[
Optional[str],
Optional[Union[str, Mapping]],
"Name of the variable to use to choose marker symbols.",
]

@@ -166,3 +166,12 @@
Union[int, float],
"The upper end of the range of values that the colormap covers.",
]

legend_sizing: TypeAlias = Annotated[
Literal["constant", "trace"],
"""
Determines if the legend items symbols scale with their corresponding
"trace" attributes or remain "constant" independent of the symbol size
on the graph.
""",
]
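
The new legend_sizing alias is passed straight through to plotly's layout.legend.itemsizing, which accepts exactly these two values: "trace" scales legend symbols with the marker sizes on the plot, "constant" keeps them a fixed size. A hedged usage sketch with toy data, not the package's API:

    import plotly.express as px

    fig = px.scatter(
        x=[1, 2, 3], y=[3, 1, 2],
        size=[5, 20, 60], color=["a", "b", "c"],
    )
    # Mirrors the update_layout calls added in this commit.
    fig.update_layout(legend=dict(itemsizing="constant", tracegroupgap=0))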
6 changes: 3 additions & 3 deletions malariagen_data/anoph/sample_metadata.py
@@ -1,5 +1,5 @@
import io
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import ipyleaflet
import numpy as np
@@ -508,14 +508,14 @@ def count_samples(
self,
sample_sets: Optional[base_params.sample_sets] = None,
sample_query: Optional[base_params.sample_query] = None,
index: Union[str, Tuple[str, ...]] = (
index: Union[str, Sequence[str]] = (
"country",
"admin1_iso",
"admin1_name",
"admin2_name",
"year",
),
columns: Union[str, Tuple[str, ...]] = "taxon",
columns: Union[str, Sequence[str]] = "taxon",
) -> pd.DataFrame:
# Load sample metadata.
df_samples = self.sample_metadata(
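
Widening index and columns from Tuple[str, ...] to Sequence[str] is purely a typing relaxation; runtime behavior is unchanged. A self-contained illustration, using a hypothetical function that mirrors the signature above:

    from typing import Sequence, Union

    def count_samples_index(index: Union[str, Sequence[str]]) -> None:
        ...

    count_samples_index(("country", "year"))  # tuple: accepted before and after
    count_samples_index(["country", "year"])  # list: only type-checks with Sequence[str]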
135 changes: 70 additions & 65 deletions malariagen_data/anoph/snp_data.py
@@ -925,15 +925,15 @@ def _snp_calls(

# Handle min cohort size.
if min_cohort_size is not None:
n_samples = ds.dims["samples"]
n_samples = ds.sizes["samples"]
if n_samples < min_cohort_size:
raise ValueError(
f"not enough samples ({n_samples}) for minimum cohort size ({min_cohort_size})"
)

# Handle max cohort size.
if max_cohort_size is not None:
n_samples = ds.dims["samples"]
n_samples = ds.sizes["samples"]
if n_samples > max_cohort_size:
rng = np.random.default_rng(seed=random_seed)
loc_downsample = rng.choice(
@@ -1439,74 +1439,79 @@ def biallelic_snp_calls(
chunks=chunks,
)

# Subset to biallelic sites.
ds_bi = ds.isel(variants=loc_bi)
with self._spinner("Prepare biallelic SNP calls"):
# Subset to biallelic sites.
ds_bi = ds.isel(variants=loc_bi)

# Start building a new dataset.
coords: Dict[str, Any] = dict()
data_vars: Dict[str, Any] = dict()
# Start building a new dataset.
coords: Dict[str, Any] = dict()
data_vars: Dict[str, Any] = dict()

# Store sample IDs.
coords["sample_id"] = ("samples",), ds_bi["sample_id"].data
# Store sample IDs.
coords["sample_id"] = ("samples",), ds_bi["sample_id"].data

# Store contig.
coords["variant_contig"] = ("variants",), ds_bi["variant_contig"].data
# Store contig.
coords["variant_contig"] = ("variants",), ds_bi["variant_contig"].data

# Store position.
coords["variant_position"] = ("variants",), ds_bi["variant_position"].data
# Store position.
coords["variant_position"] = ("variants",), ds_bi["variant_position"].data

# Store alleles, transformed.
variant_allele = ds_bi["variant_allele"].data
variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
variant_allele_out = da.map_blocks(
lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1),
variant_allele,
dtype=variant_allele.dtype,
chunks=(variant_allele.chunks[0], [2]),
)
data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out

# Store allele counts, transformed, so we don't have to recompute.
ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1)
data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out

# Store genotype calls, transformed.
gt = ds_bi["call_genotype"].data
gt_out = allel.GenotypeDaskArray(gt).map_alleles(allele_mapping)
data_vars["call_genotype"] = ("variants", "samples", "ploidy"), gt_out.values

# Build dataset.
ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs)

# Apply conditions.
if max_missing_an or min_minor_ac:
loc_out = np.ones(ds_out.dims["variants"], dtype=bool)

# Apply missingness condition.
if max_missing_an is not None:
an = ac_out.sum(axis=1)
an_missing = (ds_out.dims["samples"] * ds_out.dims["ploidy"]) - an
loc_missing = an_missing <= max_missing_an
loc_out &= loc_missing

# Apply minor allele count condition.
if min_minor_ac is not None:
ac_minor = ac_out.min(axis=1)
loc_minor = ac_minor >= min_minor_ac
loc_out &= loc_minor

ds_out = ds_out.isel(variants=loc_out)

# Try to meet target number of SNPs.
if n_snps is not None:
if ds_out.dims["variants"] > (n_snps * 2):
# Do some thinning.
thin_step = ds_out.dims["variants"] // n_snps
loc_thin = slice(thin_offset, None, thin_step)
ds_out = ds_out.isel(variants=loc_thin)

elif ds_out.dims["variants"] < n_snps:
raise ValueError("Not enough SNPs.")
# Store alleles, transformed.
variant_allele = ds_bi["variant_allele"].data
variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
variant_allele_out = da.map_blocks(
lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1),
variant_allele,
dtype=variant_allele.dtype,
chunks=(variant_allele.chunks[0], [2]),
)
data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out

# Store allele counts, transformed, so we don't have to recompute.
ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1)
data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out

# Store genotype calls, transformed.
gt = ds_bi["call_genotype"].data
gt_out = allel.GenotypeDaskArray(gt).map_alleles(allele_mapping)
data_vars["call_genotype"] = (
"variants",
"samples",
"ploidy",
), gt_out.values

# Build dataset.
ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs)

# Apply conditions.
if max_missing_an or min_minor_ac:
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)

# Apply missingness condition.
if max_missing_an is not None:
an = ac_out.sum(axis=1)
an_missing = (ds_out.sizes["samples"] * ds_out.sizes["ploidy"]) - an
loc_missing = an_missing <= max_missing_an
loc_out &= loc_missing

# Apply minor allele count condition.
if min_minor_ac is not None:
ac_minor = ac_out.min(axis=1)
loc_minor = ac_minor >= min_minor_ac
loc_out &= loc_minor

ds_out = ds_out.isel(variants=loc_out)

# Try to meet target number of SNPs.
if n_snps is not None:
if ds_out.sizes["variants"] > (n_snps * 2):
# Do some thinning.
thin_step = ds_out.sizes["variants"] // n_snps
loc_thin = slice(thin_offset, None, thin_step)
ds_out = ds_out.isel(variants=loc_thin)

elif ds_out.sizes["variants"] < n_snps:
raise ValueError("Not enough SNPs.")

return ds_out

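
Beyond wrapping the block in a progress spinner, the hunk keeps the thinning logic intact: when more than 2 * n_snps biallelic variants remain, every thin_step-th variant is kept starting at thin_offset; too few variants raises. A standalone sketch of just that step, with illustrative names and values:

    import numpy as np

    variants = np.arange(25)  # stand-in for the dataset's variants dimension
    n_snps, thin_offset = 10, 0

    if variants.size > n_snps * 2:
        thin_step = variants.size // n_snps           # 25 // 10 == 2
        thinned = variants[thin_offset::thin_step]    # about n_snps evenly spaced picks
    elif variants.size < n_snps:
        raise ValueError("Not enough SNPs.")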
35 changes: 26 additions & 9 deletions malariagen_data/anopheles.py
@@ -563,7 +563,7 @@ def _map_snp_to_aa_change_frq_ds(ds):
"event_nobs",
]

if ds.dims["variants"] == 1:
if ds.sizes["variants"] == 1:
# keep everything as-is, no need for aggregation
ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]]

@@ -2311,7 +2311,7 @@ def _gene_cnv_frequencies(
debug(
"setup output dataframe - two rows for each gene, one for amplification and one for deletion"
)
n_genes = ds_cnv.dims["genes"]
n_genes = ds_cnv.sizes["genes"]
df_genes = ds_cnv[
[
"gene_id",
@@ -2569,7 +2569,7 @@ def _gene_cnv_frequencies_advanced(
is_called = cn >= 0

debug("set up main event variables")
n_genes = ds_cnv.dims["genes"]
n_genes = ds_cnv.sizes["genes"]
n_variants, n_cohorts = n_genes * 2, len(df_cohorts)
count = np.zeros((n_variants, n_cohorts), dtype=int)
nobs = np.zeros((n_variants, n_cohorts), dtype=int)
@@ -3545,6 +3545,7 @@ def plot_frequencies_time_series(
height: plotly_params.height = None,
width: plotly_params.width = None,
title: plotly_params.title = True,
legend_sizing: plotly_params.legend_sizing = "constant",
show: plotly_params.show = True,
renderer: plotly_params.renderer = None,
**kwargs,
@@ -3634,7 +3635,10 @@ def plot_frequencies_time_series(
)

debug("tidy plot")
fig.update_layout(yaxis_range=[-0.05, 1.05])
fig.update_layout(
yaxis_range=[-0.05, 1.05],
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)

if show: # pragma: no cover
fig.show(renderer=renderer)
@@ -3833,11 +3837,15 @@ def plot_pca_coords(
color_discrete_sequence: plotly_params.color_discrete_sequence = None,
color_discrete_map: plotly_params.color_discrete_map = None,
category_orders: plotly_params.category_order = None,
legend_sizing: plotly_params.legend_sizing = "constant",
show: plotly_params.show = True,
renderer: plotly_params.renderer = None,
render_mode: plotly_params.render_mode = "svg",
**kwargs,
) -> plotly_params.figure:
# Copy input data to avoid overwriting.
data = data.copy()

# Apply jitter if desired - helps spread out points when tightly clustered.
if jitter_frac:
np.random.seed(random_seed)
@@ -3894,7 +3902,7 @@

# Tidy up.
fig.update_layout(
legend=dict(itemsizing="constant"),
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)
fig.update_traces(marker={"size": marker_size})

@@ -3930,10 +3938,14 @@ def plot_pca_coords_3d(
color_discrete_sequence: plotly_params.color_discrete_sequence = None,
color_discrete_map: plotly_params.color_discrete_map = None,
category_orders: plotly_params.category_order = None,
legend_sizing: plotly_params.legend_sizing = "constant",
show: plotly_params.show = True,
renderer: plotly_params.renderer = None,
**kwargs,
) -> plotly_params.figure:
# Copy input data to avoid overwriting.
data = data.copy()

# Apply jitter if desired - helps spread out points when tightly clustered.
if jitter_frac:
np.random.seed(random_seed)
@@ -3989,7 +4001,7 @@
# Tidy up.
fig.update_layout(
scene=dict(aspectmode="cube"),
legend=dict(itemsizing="constant"),
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)
fig.update_traces(marker={"size": marker_size})

@@ -6486,6 +6498,7 @@ def plot_haplotype_clustering(
color_discrete_sequence: plotly_params.color_discrete_sequence = None,
color_discrete_map: plotly_params.color_discrete_map = None,
category_orders: plotly_params.category_order = None,
legend_sizing: plotly_params.legend_sizing = "constant",
) -> plotly_params.figure:
import sys

@@ -6588,6 +6601,7 @@ def plot_haplotype_clustering(
title_font=dict(
size=title_font_size,
),
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)

if show: # pragma: no cover
@@ -7120,7 +7134,7 @@ def plot_njt(
category_orders: plotly_params.category_order = None,
edge_legend: bool = False,
leaf_legend: bool = True,
legend_sizing: str = "trace",
legend_sizing: plotly_params.legend_sizing = "constant",
thin_offset: base_params.thin_offset = 0,
sample_sets: Optional[base_params.sample_sets] = None,
sample_query: Optional[base_params.sample_query] = None,
@@ -7291,7 +7305,7 @@ def plot_njt(
title_font=dict(
size=title_font_size,
),
legend=dict(itemsizing=legend_sizing),
legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
)

# Style axes.
@@ -7370,7 +7384,10 @@ def _setup_plotly_sample_colors(
# Special case, default taxon colors and order.
color_params = self._setup_taxon_colors()
color_discrete_map_prepped = color_params["color_discrete_map"]
category_orders_prepped = color_params["category_orders"]
if category_orders is None:
category_orders_prepped = color_params["category_orders"]
else:
category_orders_prepped = category_orders
color_prepped = color
# Bail out early.
return (
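
This hunk is the "allow taxon order override" item from the commit message: previously the default taxon ordering was always used; now a caller-supplied category_orders takes precedence. The control flow, reduced to a runnable sketch with illustrative values:

    default_orders = {"taxon": ["gambiae", "coluzzii", "arabiensis"]}   # illustrative defaults
    category_orders = {"taxon": ["arabiensis", "coluzzii", "gambiae"]}  # user override

    if category_orders is None:
        category_orders_prepped = default_orders    # old behavior: always this branch
    else:
        category_orders_prepped = category_orders   # new: the override wins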
4 changes: 2 additions & 2 deletions malariagen_data/util.py
@@ -213,7 +213,7 @@ def dask_compress_dataset(ds, indexer, dim):
assert isinstance(indexer, da.Array)
assert indexer.ndim == 1
assert indexer.dtype == bool
assert indexer.shape[0] == ds.dims[dim]
assert indexer.shape[0] == ds.sizes[dim]

# temporarily compute the indexer once, to avoid multiple reads from
# the underlying data
@@ -571,7 +571,7 @@ def _simple_xarray_concat_arrays(

# Iterate over variable names.
for k in names:
# Access the variable from the virst dataset.
# Access the variable from the first dataset.
v = ds0[k]

if dim in v.dims: