Skip to content

Commit

Permalink
Merge pull request #404 from monarch-initiative/add-violinplot-preset
Browse files Browse the repository at this point in the history
Update plot colors, allow plotting violin plot
  • Loading branch information
ielis authored Jan 20, 2025
2 parents ed99d7a + 7fc254a commit 45b5ee1
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ based on presence of zero or one *EGFR* mutation allele:
... target=affects_egfr,
... )
>>> gt_clf.class_labels
('0', '1')
('0 alleles', '1 allele')

The ``allele_count`` needs two inputs.
The ``counts`` takes a tuple of the target allele counts,
Expand Down Expand Up @@ -102,9 +102,9 @@ and we will compare the individuals with one allele with those with two alleles:
... target=affects_lmna,
... )
>>> gt_clf.class_labels
('1', '2')
('1 allele', '2 alleles')


The classifier assigns the individuals into one of two classes:
those with one *LMNA* variant allele and those with two *LMNA* variant alleles.
Any cohort member with other allele counts (e.g. `0` or `3`) is ignored.
Any cohort member with other allele counts (e.g. `0 allele` or `3 alleles`) is ignored.
2 changes: 1 addition & 1 deletion docs/user-guide/analyses/survival.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ We can plot Kaplan-Meier curves:
... )
>>> _ = ax.set(
... xlabel=endpoint.name + " [years]",
... ylabel="Empirical survival",
... ylabel="Event-free proportion",
... )
>>> _ = ax.grid(axis="y")

Expand Down
22 changes: 21 additions & 1 deletion src/gpsea/analysis/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,26 @@ def statistic(self) -> Statistic:
"""
return self._statistic

@staticmethod
def _choose_palette_idxs(
n_categories: int,
n_colors: int,
) -> typing.Sequence[int]:
"""
Choose the color indices for coloring `n_categories` using a palette with `n_colors`.
"""
if n_colors < 2:
raise ValueError(
f"Expected a palette with at least 2 colors but got {n_colors}"
)
if n_colors < n_categories:
raise ValueError(
f"The predicate produces {n_categories} categories but the palette includes only {n_colors} colors!"
)

a = np.linspace(start=1, stop=n_colors, num=n_categories, dtype=int)
return tuple(a - 1)

def __eq__(self, value: object) -> bool:
return (
isinstance(value, AnalysisResult)
Expand Down Expand Up @@ -399,7 +419,7 @@ class MonoPhenotypeAnalysisResult(AnalysisResult, metaclass=abc.ABCMeta):
"""
Name of the data index.
"""

GT_COL = "genotype"
"""
Name of column for storing genotype data.
Expand Down
16 changes: 12 additions & 4 deletions src/gpsea/analysis/clf/_gt_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _build_ac_to_cat(

ac2cat = {}
for i, partition in enumerate(partitions):
name = " OR ".join(str(j) for j in partition)
name = " OR ".join(_pluralize(count=j, base="allele") for j in partition)
description = " OR ".join(labels[j] for j in partition)
cat = Categorization(
PatientCategory(cat_id=i, name=name, description=description),
Expand All @@ -348,6 +348,14 @@ def _build_ac_to_cat(

return ac2cat

def _pluralize(
count: int,
base: str,
) -> str:
if count == 1:
return f"{count} {base}"
else:
return f"{count} {base}s"

def allele_count(
counts: typing.Collection[typing.Union[int, typing.Collection[int]]],
Expand All @@ -372,20 +380,20 @@ def allele_count(
>>> from gpsea.analysis.clf import allele_count
>>> zero_vs_one = allele_count(counts=(0, 1))
>>> zero_vs_one.summarize_classes()
'Allele count: 0, 1'
'Allele count: 0 alleles, 1 allele'
These counts will create three classes for individuals with zero, one or two alleles:
>>> zero_vs_one_vs_two = allele_count(counts=(0, 1, 2))
>>> zero_vs_one_vs_two.summarize_classes()
'Allele count: 0, 1, 2'
'Allele count: 0 alleles, 1 allele, 2 alleles'
Last, the counts below will create two groups, one for the individuals with zero target variant type alleles,
and one for the individuals with one or two alleles:
>>> zero_vs_one_vs_two = allele_count(counts=(0, {1, 2}))
>>> zero_vs_one_vs_two.summarize_classes()
'Allele count: 0, 1 OR 2'
'Allele count: 0 alleles, 1 allele OR 2 alleles'
Note that we wrap the last two allele counts in a set.
Expand Down
22 changes: 11 additions & 11 deletions src/gpsea/analysis/clf/_test__gt_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,32 @@ def test_build_count_to_cat(
(
((0,), (1,), (2,)),
{
0: "0",
1: "1",
2: "2",
0: "0 alleles",
1: "1 allele",
2: "2 alleles",
},
),
(
((0, 1), (2,)),
{
0: "0 OR 1",
1: "0 OR 1",
2: "2",
0: "0 alleles OR 1 allele",
1: "0 alleles OR 1 allele",
2: "2 alleles",
},
),
(
((0,), (1, 2)),
{
0: "0",
1: "1 OR 2",
2: "1 OR 2",
0: "0 alleles",
1: "1 allele OR 2 alleles",
2: "1 allele OR 2 alleles",
},
),
(
((1,), (2,)),
{
1: "1",
2: "2",
1: "1 allele",
2: "2 alleles",
},
),
],
Expand Down
158 changes: 130 additions & 28 deletions src/gpsea/analysis/pscore/_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import abc
import math
import typing

import numpy as np
import pandas as pd

from gpsea.model import Patient
from gpsea.config import PALETTE_DATA, PALETTE_SPECIAL
from ..clf import GenotypeClassifier
from .stats import PhenotypeScoreStatistic

Expand Down Expand Up @@ -128,6 +131,16 @@ def __init__(
super().__init__(gt_clf, phenotype, statistic, data, statistic_result)
assert isinstance(phenotype, PhenotypeScorer)

# Check that the provided genotype predicate defines the same categories
# as those found in `data.`
actual = set(
int(val)
for val in data[MonoPhenotypeAnalysisResult.GT_COL].unique()
if val is not None and not math.isnan(val)
)
expected = set(c.cat_id for c in self._gt_clf.get_categories())
assert actual == expected, "Mismatch in the genotype classes"

def phenotype_scorer(self) -> PhenotypeScorer:
"""
Get the scorer that computed the phenotype score.
Expand All @@ -137,31 +150,21 @@ def phenotype_scorer(self) -> PhenotypeScorer:
# being a subclass of `Partitioning`.
return self._phenotype # type: ignore

def plot_boxplots(
def _make_data_df(
self,
ax,
colors=("darksalmon", "honeydew"),
median_color: str = "black",
):
"""
Draw box plot with distributions of phenotype scores for the genotype groups.
:param gt_predicate: the genotype predicate used to produce the genotype groups.
:param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
:param colors: a sequence with colors to use for coloring the box patches of the box plot.
:param median_color: a `str` with the color for the boxplot median line.
"""
) -> pd.DataFrame:
# skip the patients with unassigned genotype group
bla = self._data.notna()
not_na_gts = bla.all(axis="columns")
data = self._data.loc[not_na_gts]

# Check that the provided genotype predicate defines the same categories
# as those found in `data.`
actual = set(data[MonoPhenotypeAnalysisResult.GT_COL].unique())
expected = set(c.cat_id for c in self._gt_clf.get_categories())
assert actual == expected, "Mismatch in the genotype classes"
not_na = self._data.notna()
not_na_gts = not_na.all(axis="columns")
return self._data.loc[not_na_gts]

def _make_x_and_tick_labels(
self,
data: pd.DataFrame,
) -> typing.Tuple[
typing.Sequence[typing.Sequence[float]],
typing.Sequence[str],
]:
x = [
data.loc[
data[MonoPhenotypeAnalysisResult.GT_COL] == c.category.cat_id,
Expand All @@ -171,19 +174,116 @@ def plot_boxplots(
]

gt_cat_names = [c.category.name for c in self._gt_clf.get_categorizations()]

return x, gt_cat_names

def plot_boxplots(
self,
ax,
colors: typing.Sequence[str] = PALETTE_DATA,
median_color: str = PALETTE_SPECIAL,
**boxplot_kwargs,
):
"""
Draw box plot with distributions of phenotype scores for the genotype groups.
:param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
:param colors: a sequence with color palette for the box plot patches.
:param median_color: a `str` with the color for the boxplot median line.
:param boxplot_kwargs: arguments to pass into :func:`matplotlib.axes.Axes.boxplot` function.
"""
data = self._make_data_df()

x, gt_cat_names = self._make_x_and_tick_labels(data)
patch_artist = boxplot_kwargs.pop("patch_artist", True)
tick_labels = boxplot_kwargs.pop("tick_labels", gt_cat_names)

bplot = ax.boxplot(
x=x,
patch_artist=True,
tick_labels=gt_cat_names,
patch_artist=patch_artist,
tick_labels=tick_labels,
**boxplot_kwargs,
)

# Set face colors of the boxes
for patch, color in zip(bplot["boxes"], colors):
patch.set_facecolor(color)
col_idxs = self._choose_palette_idxs(
n_categories=self._gt_clf.n_categorizations(), n_colors=len(colors)
)
for patch, col_idx in zip(bplot["boxes"], col_idxs):
patch.set_facecolor(colors[col_idx])

for median in bplot['medians']:
for median in bplot["medians"]:
median.set_color(median_color)

def plot_violins(
self,
ax,
colors: typing.Sequence[str] = PALETTE_DATA,
**violinplot_kwargs,
):
"""
Draw a violin plot with distributions of phenotype scores for the genotype groups.
:param ax: the Matplotlib :class:`~matplotlib.axes.Axes` to draw on.
:param colors: a sequence with color palette for the violin patches.
:param violinplot_kwargs: arguments to pass into :func:`matplotlib.axes.Axes.violinplot` function.
"""
data = self._make_data_df()

x, gt_cat_names = self._make_x_and_tick_labels(data)

showmeans = violinplot_kwargs.pop("showmeans", False)
showextrema = violinplot_kwargs.pop("showextrema", False)

parts = ax.violinplot(
dataset=x,
showmeans=showmeans,
showextrema=showextrema,
**violinplot_kwargs,
)

# quartile1, medians, quartile3 = np.percentile(x, [25, 50, 75], axis=1)
quartile1 = [np.percentile(v, 25) for v in x]
medians = [np.median(v) for v in x]
quartile3 = [np.percentile(v, 75) for v in x]
x = [sorted(val) for val in x]
whiskers = np.array(
[
PhenotypeScoreAnalysisResult._adjacent_values(sorted_array, q1, q3)
for sorted_array, q1, q3 in zip(x, quartile1, quartile3)
]
)
whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]

inds = np.arange(1, len(medians) + 1)
ax.scatter(inds, medians, marker="o", color="white", s=30, zorder=3)
ax.vlines(inds, quartile1, quartile3, color="k", linestyle="-", lw=5)
ax.vlines(inds, whiskers_min, whiskers_max, color="k", linestyle="-", lw=1)

ax.xaxis.set(
ticks=np.arange(1, len(gt_cat_names) + 1),
ticklabels=gt_cat_names,
)

col_idxs = self._choose_palette_idxs(
n_categories=self._gt_clf.n_categorizations(), n_colors=len(colors)
)
for pc, color_idx in zip(parts["bodies"], col_idxs):
pc.set(
facecolor=colors[color_idx],
edgecolor=None,
alpha=1,
)

@staticmethod
def _adjacent_values(vals, q1, q3):
upper_adjacent_value = q3 + (q3 - q1) * 1.5
upper_adjacent_value = np.clip(upper_adjacent_value, q3, vals[-1])

lower_adjacent_value = q1 - (q3 - q1) * 1.5
lower_adjacent_value = np.clip(lower_adjacent_value, vals[0], q1)
return lower_adjacent_value, upper_adjacent_value

def __eq__(self, value: object) -> bool:
return isinstance(value, PhenotypeScoreAnalysisResult) and super(
MonoPhenotypeAnalysisResult, self
Expand Down Expand Up @@ -254,7 +354,9 @@ def compare_genotype_vs_phenotype_score(
for individual in cohort:
gt_cat = gt_clf.test(individual)
if gt_cat is None:
data.loc[individual.patient_id, MonoPhenotypeAnalysisResult.GT_COL] = None
data.loc[individual.patient_id, MonoPhenotypeAnalysisResult.GT_COL] = (
None
)
else:
data.loc[individual.patient_id, MonoPhenotypeAnalysisResult.GT_COL] = (
gt_cat.category.cat_id
Expand Down
Loading

0 comments on commit 45b5ee1

Please sign in to comment.