Refactor SNP frequencies (#490)
* wip refactor snp frequencies

* wip refactor snp frequencies

* wip

* wip refactor

* fixes

* wip migrate tests

* wip migrate tests

* wip migrate tests; allow any sample metadata column as cohorts

* ruff fix

* wip migrate tests

* refactor tests

* wip tests

* fix tests

* wip migrate tests

* finish migrating snp frequency tests; fix nan frequency bug

* minor tweaks

* plot tests

* relax

* improve coverage

* more testing

* coverage

* get test durations

* duh

* relax more

* typing
alimanfoo authored Jan 3, 2024
1 parent 7995a0f commit f4c263f
Showing 21 changed files with 3,072 additions and 3,130 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
@@ -28,7 +28,7 @@ jobs:
         run: poetry install
 
       - name: Run tests with coverage
-        run: poetry run pytest -v --cov malariagen_data/anoph --cov-report=xml tests/anoph
+        run: poetry run pytest --durations=20 --durations-min=1.0 -v --cov malariagen_data/anoph --cov-report=xml tests/anoph
 
       - name: Upload coverage report
         uses: codecov/codecov-action@v3
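The added flags ask pytest to report test timings: --durations=20 prints the twenty slowest tests, and --durations-min=1.0 hides any entry faster than one second. The same timing report can be reproduced locally (assuming poetry and the dev dependencies are installed) by running the workflow command as-is:

    poetry run pytest --durations=20 --durations-min=1.0 -v --cov malariagen_data/anoph --cov-report=xml tests/anoph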
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -41,7 +41,7 @@ jobs:
           key: gcs_cache_tests_20231119
 
       - name: Run full test suite
-        run: poetry run pytest -v tests
+        run: poetry run pytest --durations=20 --durations-min=10.0 -v tests
 
       - name: Save GCS cache
         uses: actions/cache/save@v3
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -12,5 +12,7 @@ repos:
     hooks:
       # Run the linter.
       - id: ruff
+        args:
+          - "--fix"
       # Run the formatter.
       - id: ruff-format
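With the added args, the ruff hook now applies safe autofixes rather than only reporting lint violations. To exercise all configured hooks against the whole repository locally (assuming pre-commit is installed), the standard invocation is:

    pre-commit run --all-files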
19 changes: 1 addition & 18 deletions malariagen_data/af1.py
@@ -128,6 +128,7 @@ def __init__(
             tqdm_class=tqdm_class,
             taxon_colors=TAXON_COLORS,
             virtual_contigs=None,
+            gene_names=None,
         )
 
     def __repr__(self):
@@ -209,21 +210,3 @@ def _repr_html_(self):
             </table>
         """
         return html
-
-    def _transcript_to_gene_name(self, transcript):
-        df_genome_features = self.genome_features().set_index("ID")
-        rec_transcript = df_genome_features.loc[transcript]
-        parent = rec_transcript["Parent"]
-
-        # E.g. manual overrides (used in Ag3)
-        # if parent == "AGAP004707":
-        #     parent_name = "Vgsc/para"
-        # else:
-        #     parent_name = rec_parent["Name"]
-
-        # Note: Af1 doesn't have the "Name" attribute
-        # rec_parent = df_genome_features.loc[parent]
-        # parent_name = rec_parent["Name"]
-        parent_name = parent
-
-        return parent_name
18 changes: 4 additions & 14 deletions malariagen_data/ag3.py
@@ -27,6 +27,9 @@
     "3RL": ("3R", "3L"),
     "23X": ("2R", "2L", "3R", "3L", "X"),
 }
+GENE_NAMES = {
+    "AGAP004707": "Vgsc/para",
+}
 
 
 def _setup_aim_palettes():
@@ -191,6 +194,7 @@ def __init__(
             tqdm_class=tqdm_class,
             taxon_colors=TAXON_COLORS,
             virtual_contigs=VIRTUAL_CONTIGS,
+            gene_names=GENE_NAMES,
         )
 
         # set up caches
@@ -293,20 +297,6 @@ def _repr_html_(self):
         """
         return html
 
-    def _transcript_to_gene_name(self, transcript):
-        df_genome_features = self.genome_features().set_index("ID")
-        rec_transcript = df_genome_features.loc[transcript]
-        parent = rec_transcript["Parent"]
-        rec_parent = df_genome_features.loc[parent]
-
-        # manual overrides
-        if parent == "AGAP004707":
-            parent_name = "Vgsc/para"
-        else:
-            parent_name = rec_parent["Name"]
-
-        return parent_name
-
     def cross_metadata(self):
         """Load a dataframe containing metadata about samples in colony crosses,
         including which samples are parents or progeny in which crosses.
101 changes: 51 additions & 50 deletions malariagen_data/anoph/cnv_data.py
@@ -163,60 +163,61 @@ def cnv_hmm(
         regions: List[Region] = parse_multi_region(self, region)
         del region
 
-        debug("access CNV HMM data and concatenate as needed")
-        lx = []
-        for r in regions:
-            ly = []
-            for s in sample_sets:
-                y = self._cnv_hmm_dataset(
-                    contig=r.contig,
-                    sample_set=s,
-                    inline_array=inline_array,
-                    chunks=chunks,
-                )
-                ly.append(y)
-
-            debug("concatenate data from multiple sample sets")
-            x = simple_xarray_concat(ly, dim=DIM_SAMPLE)
-
-            debug("handle region, do this only once - optimisation")
-            if r.start is not None or r.end is not None:
-                start = x["variant_position"].values
-                end = x["variant_end"].values
-                index = pd.IntervalIndex.from_arrays(start, end, closed="both")
-                # noinspection PyArgumentList
-                other = pd.Interval(r.start, r.end, closed="both")
-                loc_region = index.overlaps(other)  # type: ignore
-                x = x.isel(variants=loc_region)
-
-            lx.append(x)
-
-        debug("concatenate data from multiple regions")
-        ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
-
-        debug("handle sample query")
-        if sample_query is not None:
-            debug("load sample metadata")
-            df_samples = self.sample_metadata(sample_sets=sample_sets)
-
-            debug("align sample metadata with CNV data")
-            cnv_samples = ds["sample_id"].values.tolist()
-            df_samples_cnv = (
-                df_samples.set_index("sample_id").loc[cnv_samples].reset_index()
-            )
-
-            debug("apply the query")
-            loc_query_samples = df_samples_cnv.eval(sample_query).values
-            if np.count_nonzero(loc_query_samples) == 0:
-                raise ValueError(f"No samples found for query {sample_query!r}")
-
-            ds = ds.isel(samples=loc_query_samples)
-
-        debug("handle coverage variance filter")
-        if max_coverage_variance is not None:
-            cov_var = ds["sample_coverage_variance"].values
-            loc_pass_samples = cov_var <= max_coverage_variance
-            ds = ds.isel(samples=loc_pass_samples)
+        with self._spinner("Access CNV HMM data"):
+            debug("access CNV HMM data and concatenate as needed")
+            lx = []
+            for r in regions:
+                ly = []
+                for s in sample_sets:
+                    y = self._cnv_hmm_dataset(
+                        contig=r.contig,
+                        sample_set=s,
+                        inline_array=inline_array,
+                        chunks=chunks,
+                    )
+                    ly.append(y)
+
+                debug("concatenate data from multiple sample sets")
+                x = simple_xarray_concat(ly, dim=DIM_SAMPLE)
+
+                debug("handle region, do this only once - optimisation")
+                if r.start is not None or r.end is not None:
+                    start = x["variant_position"].values
+                    end = x["variant_end"].values
+                    index = pd.IntervalIndex.from_arrays(start, end, closed="both")
+                    # noinspection PyArgumentList
+                    other = pd.Interval(r.start, r.end, closed="both")
+                    loc_region = index.overlaps(other)  # type: ignore
+                    x = x.isel(variants=loc_region)
+
+                lx.append(x)
+
+            debug("concatenate data from multiple regions")
+            ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
+
+            debug("handle sample query")
+            if sample_query is not None:
+                debug("load sample metadata")
+                df_samples = self.sample_metadata(sample_sets=sample_sets)
+
+                debug("align sample metadata with CNV data")
+                cnv_samples = ds["sample_id"].values.tolist()
+                df_samples_cnv = (
+                    df_samples.set_index("sample_id").loc[cnv_samples].reset_index()
+                )
+
+                debug("apply the query")
+                loc_query_samples = df_samples_cnv.eval(sample_query).values
+                if np.count_nonzero(loc_query_samples) == 0:
+                    raise ValueError(f"No samples found for query {sample_query!r}")
+
+                ds = ds.isel(samples=loc_query_samples)
+
+            debug("handle coverage variance filter")
+            if max_coverage_variance is not None:
+                cov_var = ds["sample_coverage_variance"].values
+                loc_pass_samples = cov_var <= max_coverage_variance
+                ds = ds.isel(samples=loc_pass_samples)
 
         return ds
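The functional change in this file is that the data access and filtering logic is wrapped in a with self._spinner(...) block, so users see progress feedback while the HMM data is assembled; the selection logic itself is only re-indented. As a rough illustration only (the real _spinner implementation is not shown in this diff and may well differ), such a context manager could be sketched like this:

    from contextlib import contextmanager

    @contextmanager
    def spinner(desc):
        # Hypothetical stand-in for self._spinner: print a status message while
        # the wrapped block runs, then report completion when it exits.
        print(f"{desc} ...", flush=True)
        try:
            yield
        finally:
            print(f"{desc} ... done", flush=True)

    # Usage mirrors the change above:
    # with spinner("Access CNV HMM data"):
    #     ...  # assemble and filter the dataset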
5 changes: 5 additions & 0 deletions malariagen_data/anoph/frq_params.py
@@ -65,3 +65,8 @@
     `gene_cnv_frequencies_advanced()`.
     """,
 ]
+
+include_counts: TypeAlias = Annotated[
+    bool,
+    "Include columns with allele counts and number of non-missing allele calls (nobs).",
+]
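The new include_counts entry follows the same pattern as the other parameters in frq_params: a TypeAlias built from Annotated, pairing a parameter's type with its documentation string so the description can be shared across method signatures. A minimal, self-contained sketch of the pattern (the consuming function name here is hypothetical, and Python >= 3.10 is assumed for TypeAlias in typing):

    from typing import Annotated, TypeAlias, get_type_hints

    # Parameter type plus human-readable description, as in frq_params.
    include_counts: TypeAlias = Annotated[
        bool,
        "Include columns with allele counts and number of non-missing allele calls (nobs).",
    ]

    def example_frequencies(*, include_counts: include_counts = False):
        # Hypothetical consumer of the annotated parameter type.
        ...

    # The description travels with the type and can be recovered at runtime.
    hints = get_type_hints(example_frequencies, include_extras=True)
    print(hints["include_counts"].__metadata__[0])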
28 changes: 26 additions & 2 deletions malariagen_data/anoph/genome_features.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Mapping
 
 import bokeh.models
 import bokeh.plotting
@@ -26,6 +26,7 @@ def __init__(
         *,
         gff_gene_type: str,
         gff_default_attributes: Tuple[str, ...],
+        gene_names: Optional[Mapping[str, str]] = None,
         **kwargs,
     ):
         # N.B., this class is designed to work cooperatively, and
@@ -38,14 +39,19 @@ def __init__(
         self._gff_gene_type = gff_gene_type
         self._gff_default_attributes = gff_default_attributes
 
+        # Allow manual override of gene names.
+        if gene_names is None:
+            gene_names = dict()
+        self._gene_name_overrides = gene_names
+
         # Setup caches.
         self._cache_genome_features: Dict[Tuple[str, ...], pd.DataFrame] = dict()
 
     @property
     def _geneset_gff3_path(self):
         return self.config["GENESET_GFF3_PATH"]
 
-    def geneset(self, *args, **kwargs):
+    def geneset(self, *args, **kwargs):  # pragma: no cover
         """Deprecated, this method has been renamed to genome_features()."""
         return self.genome_features(*args, **kwargs)
 
@@ -429,3 +435,21 @@ def _bokeh_style_genome_xaxis(fig, contig):
         fig.xaxis.ticker = bokeh.models.AdaptiveTicker(min_interval=1)
         fig.xaxis.minor_tick_line_color = None
         fig.xaxis[0].formatter = bokeh.models.NumeralTickFormatter(format="0,0")
+
+    def _transcript_to_parent_name(self, transcript):
+        df_genome_features = self.genome_features().set_index("ID")
+
+        try:
+            rec_transcript = df_genome_features.loc[transcript]
+        except KeyError:
+            return None
+
+        parent_id = rec_transcript["Parent"]
+
+        try:
+            # Manual override.
+            return self._gene_name_overrides[parent_id]
+        except KeyError:
+            rec_parent = df_genome_features.loc[parent_id]
+            # Try to access "Name" attribute, fall back to "ID" if not present.
+            return rec_parent.get("Name", parent_id)
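This new shared helper replaces the per-class _transcript_to_gene_name methods removed from af1.py and ag3.py: the transcript's parent gene is looked up in the genome features table, a manual override (such as Ag3's AGAP004707 -> "Vgsc/para") wins if present, and otherwise the "Name" attribute is used with the gene ID as a fallback. A standalone sketch of the same lookup logic over a toy features table (the column layout is illustrative, not the real GFF schema):

    import pandas as pd

    # Toy genome-features table, indexed by feature ID (illustrative only).
    df_genome_features = pd.DataFrame(
        {
            "ID": ["AGAP004707", "AGAP004707-RA"],
            "Parent": [None, "AGAP004707"],
        }
    ).set_index("ID")

    gene_name_overrides = {"AGAP004707": "Vgsc/para"}

    def transcript_to_parent_name(transcript):
        try:
            rec_transcript = df_genome_features.loc[transcript]
        except KeyError:
            return None
        parent_id = rec_transcript["Parent"]
        try:
            # Manual override wins, e.g. Vgsc/para for AGAP004707.
            return gene_name_overrides[parent_id]
        except KeyError:
            rec_parent = df_genome_features.loc[parent_id]
            # Fall back to the parent gene ID when no "Name" attribute exists.
            return rec_parent.get("Name", parent_id)

    print(transcript_to_parent_name("AGAP004707-RA"))  # -> "Vgsc/para"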
34 changes: 34 additions & 0 deletions malariagen_data/anoph/sample_metadata.py
@@ -969,3 +969,37 @@ def _setup_sample_hover_data_plotly(
     if symbol and symbol not in hover_data:
         hover_data.append(symbol)
     return hover_data
+
+
+def locate_cohorts(*, cohorts, data):
+    # Build cohort dictionary where key=cohort_id, value=loc_coh.
+    coh_dict = {}
+
+    if isinstance(cohorts, Mapping):
+        # User has supplied a custom dictionary mapping cohort identifiers
+        # to pandas queries.
+
+        for coh, query in cohorts.items():
+            loc_coh = data.eval(query).values
+            coh_dict[coh] = loc_coh
+
+    else:
+        assert isinstance(cohorts, str)
+        # User has supplied the name of a sample metadata column.
+
+        # Convenience to allow things like "admin1_year" instead of "cohort_admin1_year".
+        if "cohort_" + cohorts in data.columns:
+            cohorts = "cohort_" + cohorts
+
+        # Check the given cohort set exists.
+        if cohorts not in data.columns:
+            raise ValueError(f"{cohorts!r} is not a known column in the data.")
+        cohort_labels = data[cohorts].unique()
+
+        # Remove the nans and sort.
+        cohort_labels = sorted([c for c in cohort_labels if isinstance(c, str)])
+        for coh in cohort_labels:
+            loc_coh = data[cohorts] == coh
+            coh_dict[coh] = loc_coh.values
+
+    return coh_dict
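This module-level locate_cohorts helper is what implements "allow any sample metadata column as cohorts" from the commit message: it accepts either a mapping of cohort labels to pandas query strings, or the name of a metadata column (with or without the "cohort_" prefix), and returns a dictionary of boolean sample selections. A small self-contained usage sketch over a toy metadata frame (the sample values are made up for illustration; the import path follows the file shown above):

    import pandas as pd
    from malariagen_data.anoph.sample_metadata import locate_cohorts

    # Toy sample metadata (values are illustrative only).
    df_samples = pd.DataFrame(
        {
            "sample_id": ["A1", "A2", "A3", "A4"],
            "taxon": ["gambiae", "coluzzii", "gambiae", "coluzzii"],
            "cohort_admin1_year": ["ML-2_2014", "ML-2_2014", "BF-09_2012", None],
        }
    )

    # Cohorts given as a mapping of label -> pandas query.
    coh_dict = locate_cohorts(
        cohorts={"gamb": "taxon == 'gambiae'", "colu": "taxon == 'coluzzii'"},
        data=df_samples,
    )
    print({k: v.sum() for k, v in coh_dict.items()})  # {'gamb': 2, 'colu': 2}

    # Cohorts given as a metadata column; "admin1_year" resolves to
    # "cohort_admin1_year", and missing (None) values are dropped.
    coh_dict = locate_cohorts(cohorts="admin1_year", data=df_samples)
    print(sorted(coh_dict))  # ['BF-09_2012', 'ML-2_2014']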