-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add compat file for different seaborn versions
compatibility with seaborn>=0.11; all tests pass
- Loading branch information
Showing
7 changed files
with
1,114 additions
and
289 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
numpy>=1.12.1 | ||
seaborn>=0.9.0,<0.12 | ||
seaborn>=0.9.0 | ||
matplotlib>=2.2.2 | ||
pandas>=0.23.0,<2.0.0 | ||
pandas>=0.23.0 | ||
scipy>=1.1.0 | ||
statsmodels | ||
packaging |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,160 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Iterator, Sequence | ||
import itertools | ||
from typing import TYPE_CHECKING | ||
import warnings | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
if TYPE_CHECKING: | ||
from .compat import TupleGroup, TGroupValue, THueValue | ||
|
||
|
||
def get_group_names_and_labels( | ||
group_names: Sequence[TGroupValue], | ||
hue_names: Sequence[THueValue], | ||
) -> tuple[list[TupleGroup], list[str]]: | ||
tuple_group_names: list[TupleGroup] | ||
if len(hue_names) == 0: | ||
tuple_group_names = [(name,) for name in group_names] | ||
labels = [str(name) for name in group_names] | ||
|
||
from statannotations.utils import get_closest | ||
else: | ||
labels = [] | ||
tuple_group_names = [] | ||
for group_name, hue_name in itertools.product(group_names, hue_names): | ||
tuple_group_names.append((group_name, hue_name)) | ||
labels.append(f'{group_name}_{hue_name}') | ||
|
||
return tuple_group_names, labels | ||
|
||
|
||
class _GroupsPositions: | ||
def __init__(self, plotter, group_names): | ||
self._plotter = plotter | ||
self._hue_names = self._plotter.hue_names | ||
POSITION_TOLERANCE: float = 0.1 | ||
|
||
if self._hue_names is not None: | ||
nb_hues = len(self._hue_names) | ||
if nb_hues == 1: | ||
raise ValueError( | ||
"Using hues with only one hue is not supported.") | ||
width: float | ||
tuple_group_names: list[TupleGroup] | ||
labels: list[str] | ||
_data: pd.DataFrame | ||
|
||
self.hue_offsets = self._plotter.hue_offsets | ||
self._axis_units = self.hue_offsets[1] - self.hue_offsets[0] | ||
def __init__( | ||
self, | ||
group_names: Sequence[TGroupValue], | ||
hue_names: Sequence[THueValue], | ||
*, | ||
dodge: bool = True, | ||
gap: float = 0.0, | ||
width: float = 0.8, | ||
native_group_offsets: Sequence | None = None, | ||
) -> None: | ||
self.gap = gap | ||
self.dodge = dodge | ||
|
||
self._groups_positions = { | ||
np.round(self.get_group_axis_position(group_name), 1): group_name | ||
for group_name in group_names | ||
} | ||
self._group_names = group_names | ||
self._hue_names = hue_names | ||
self.use_hue = len(hue_names) == 0 | ||
|
||
self._groups_positions_list = sorted(self._groups_positions.keys()) | ||
# Compute the coordinates of the groups (without hue) and the width | ||
self.group_offsets, self.width = self._set_group_offsets( | ||
group_names, native_group_offsets, width | ||
) | ||
# Create the tuple (group, hue) and the labels | ||
self.tuple_group_names, self.labels = get_group_names_and_labels(group_names, hue_names) | ||
|
||
if self._hue_names is None: | ||
self._axis_units = ((max(list(self._groups_positions.keys())) + 1) | ||
/ len(self._groups_positions)) | ||
# Create dataframe with the groups, labels and positions | ||
# this should be done last, when the other attributes are defined | ||
self._data, self._artist_width = self._set_data(dodge=dodge, gap=gap) | ||
|
||
self._axis_ranges = { | ||
(pos - self._axis_units / 2, | ||
pos + self._axis_units / 2, | ||
pos): group_name | ||
for pos, group_name in self._groups_positions.items()} | ||
def _set_group_offsets( | ||
self, | ||
group_names: Sequence, | ||
native_group_offsets: Sequence | None, | ||
width: float, | ||
) -> tuple[Sequence, float]: | ||
"""Set the group offsets from native scale and scale the width.""" | ||
group_offsets = list(range(len(group_names))) | ||
if native_group_offsets is not None: | ||
curated_offsets = [v for v in native_group_offsets] | ||
if len(curated_offsets) != len(group_names): | ||
msg = ( | ||
'The values of the categories with "native_scale=True" do not correspond ' | ||
'to the category names. Maybe some values are not finite?' | ||
) | ||
warnings.warn(msg) | ||
else: | ||
group_offsets = curated_offsets | ||
if len(curated_offsets) > 1: | ||
native_width = np.min(np.diff(curated_offsets)) | ||
width *= native_width | ||
|
||
@property | ||
def axis_positions(self): | ||
return self._groups_positions | ||
return group_offsets, width | ||
|
||
@property | ||
def axis_units(self): | ||
return self._axis_units | ||
def _set_data(self, dodge: bool, gap: float) -> tuple[pd.DataFrame, float]: | ||
n_repeat = max(len(self._hue_names), 1) | ||
group_positions = np.array(self.group_offsets) | ||
positions = np.repeat(group_positions, n_repeat) | ||
artist_width = float(self.width) | ||
data = pd.DataFrame( | ||
{ | ||
'group': self.tuple_group_names, | ||
'label': self.labels, | ||
'pos': positions, | ||
}, | ||
) | ||
if dodge and self.use_hue: | ||
n_hues = max(len(self._hue_names), 1) | ||
artist_width /= n_hues | ||
# evenly spaced range centered on zero (subtracting the mean) | ||
offset = artist_width * (np.arange(n_hues) - (n_hues - 1) / 2) | ||
tiled_offset = np.tile(offset, len(self._group_names)) | ||
data['pos'] += tiled_offset | ||
if gap and gap >= 0 and gap <= 1: | ||
artist_width *= 1 - gap | ||
|
||
def get_axis_pos_location(self, pos): | ||
""" | ||
Finds the x-axis location of a categorical variable | ||
""" | ||
for axis_range in self._axis_ranges: | ||
if (pos >= axis_range[0]) & (pos <= axis_range[1]): | ||
return axis_range[2] | ||
return data, artist_width | ||
|
||
def get_group_axis_position(self, group): | ||
""" | ||
group_name can be either a name "cat" or a tuple ("cat", "hue") | ||
def find_group_at_pos(self, pos: float, *, verbose: bool = False) -> TupleGroup | None: | ||
positions = self._data['pos'] | ||
if len(positions) == 0: | ||
return None | ||
# Get the index of the closest position | ||
index = (positions - pos).abs().idxmin() | ||
found_pos = positions.loc[index] | ||
|
||
if verbose and abs(found_pos - pos) > self.POSITION_TOLERANCE: | ||
# The requested position is not an artist position | ||
msg = ( | ||
'Invalid x-position found. Are the same parameters passed to ' | ||
'seaborn and statannotations calls? Or are there few data points? ' | ||
f'The closest group position to {pos} is {found_pos}' | ||
) | ||
warnings.warn(msg) | ||
return self._data.loc[index, 'group'] | ||
|
||
def get_group_axis_position(self, group: TupleGroup) -> float: | ||
"""Get the position of the group. | ||
`group` can be either a 1-tuple ("group",) or a 2-tuple ("group", "hue") | ||
""" | ||
if self._plotter.plot_hues is None: | ||
cat = group | ||
hue_offset = 0 | ||
else: | ||
cat = group[0] | ||
hue_level = group[1] | ||
hue_offset = self._plotter.hue_offsets[ | ||
self._plotter.hue_names.index(hue_level)] | ||
|
||
group_pos = self._plotter.group_names.index(cat) + hue_offset | ||
return group_pos | ||
|
||
def find_closest(self, pos): | ||
return get_closest(list(self._groups_positions_list), pos) | ||
group_names = self._data['group'] | ||
if group not in group_names: | ||
msg = f'Group {group} was not found in the list: {group_names}' | ||
raise ValueError(msg) | ||
index = (group_names == group).idxmax() | ||
pos = float(self._data.loc[index, 'pos']) | ||
# round the position | ||
return round(pos / self.POSITION_TOLERANCE) * self.POSITION_TOLERANCE | ||
|
||
@property | ||
def artist_width(self) -> float: | ||
return float(self._artist_width) | ||
|
||
def compatible_width(self, width: float) -> bool: | ||
"""Check if the rectangle width is smaller than the artist width.""" | ||
return abs(width) <= 1.1 * self._artist_width | ||
|
||
def iter_groups(self) -> Iterator[tuple[TupleGroup, str, float]]: | ||
"""Iterate the groups and return a tuple (group_tuple, group_label, group_position).""" | ||
yield from self._data[['group', 'label', 'pos']].itertuples(index=False, name=None) |
Oops, something went wrong.