Skip to content

Commit

Permalink
add compat file for different seaborn versions
Browse files Browse the repository at this point in the history
compatibility with seaborn>=0.11

pass all tests
  • Loading branch information
getzze committed Jul 1, 2024
1 parent 8f148c1 commit 649e375
Show file tree
Hide file tree
Showing 7 changed files with 1,114 additions and 289 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numpy>=1.12.1
seaborn>=0.9.0,<0.12
seaborn>=0.9.0
matplotlib>=2.2.2
pandas>=0.23.0,<2.0.0
pandas>=0.23.0
scipy>=1.1.0
statsmodels
packaging
7 changes: 5 additions & 2 deletions statannotations/Annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def print_labels_and_content(self, sep=" vs. "):

def check_data_stat_result(self):
    """Check that ``self.data`` is a ``StatResult``.

    Emits a single warning naming the offending type when it is not.

    :returns: True when ``self.data`` is a ``StatResult``, False otherwise.
    """
    if not isinstance(self.data, StatResult):
        # Only the current message is emitted; the superseded two-part
        # warning is dropped so the user is not warned twice.
        msg = (
            "Cannot annotate current pair. Annotation data has incorrect "
            f"class, should be StatResult: {type(self.data)}"
        )
        warnings.warn(msg)
        return False
    return True
198 changes: 144 additions & 54 deletions statannotations/_GroupsPositions.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,160 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
import itertools
from typing import TYPE_CHECKING
import warnings

import numpy as np
import pandas as pd

if TYPE_CHECKING:
from .compat import TupleGroup, TGroupValue, THueValue


def get_group_names_and_labels(
    group_names: Sequence[TGroupValue],
    hue_names: Sequence[THueValue],
) -> tuple[list[TupleGroup], list[str]]:
    """Build the group tuples and their text labels.

    Without hues, every group becomes a 1-tuple ``(group,)`` labelled
    ``str(group)``. With hues, every ``(group, hue)`` combination is produced
    in ``itertools.product`` order (hues vary fastest) and labelled
    ``"{group}_{hue}"``.

    :param group_names: names of the groups, in plotting order.
    :param hue_names: names of the hue levels; empty when hue is not used.
    :returns: ``(tuple_group_names, labels)``, two lists of equal length.
    """
    tuple_group_names: list[TupleGroup]
    # ``len(...) == 0`` rather than truthiness so array-like sequences work.
    if len(hue_names) == 0:
        tuple_group_names = [(name,) for name in group_names]
        labels = [str(name) for name in group_names]
    else:
        tuple_group_names = list(itertools.product(group_names, hue_names))
        labels = [
            f'{group_name}_{hue_name}'
            for group_name, hue_name in tuple_group_names
        ]

    return tuple_group_names, labels


class _GroupsPositions:
    """Positions of the plotted groups along the categorical axis.

    Each group is a tuple ``(group,)`` or ``(group, hue)``. The class stores,
    for every group, its label and its position in a dataframe with the
    columns ``'group'``, ``'label'`` and ``'pos'``.
    """

    # Positions are rounded to, and compared with, this tolerance.
    POSITION_TOLERANCE: float = 0.1

    width: float
    tuple_group_names: list[TupleGroup]
    labels: list[str]
    _data: pd.DataFrame

    def __init__(
        self,
        group_names: Sequence[TGroupValue],
        hue_names: Sequence[THueValue],
        *,
        dodge: bool = True,
        gap: float = 0.0,
        width: float = 0.8,
        native_group_offsets: Sequence | None = None,
    ) -> None:
        """Compute the position of every group.

        :param group_names: names of the groups, in plotting order.
        :param hue_names: names of the hue levels; empty when hue is not used.
        :param dodge: whether hue levels are shifted around the group position.
        :param gap: fraction (between 0 and 1) by which the artist width is
            shrunk.
        :param width: width allotted to one group.
        :param native_group_offsets: optional positions of the groups; must
            have one value per group, otherwise they are ignored with a
            warning.
        """
        self.gap = gap
        self.dodge = dodge

        self._group_names = group_names
        self._hue_names = hue_names
        # Hue is in use when at least one hue level is given.
        self.use_hue = len(hue_names) > 0

        # Compute the coordinates of the groups (without hue) and the width
        self.group_offsets, self.width = self._set_group_offsets(
            group_names, native_group_offsets, width
        )
        # Create the tuple (group, hue) and the labels
        self.tuple_group_names, self.labels = get_group_names_and_labels(
            group_names, hue_names
        )

        # Create dataframe with the groups, labels and positions
        # this should be done last, when the other attributes are defined
        self._data, self._artist_width = self._set_data(dodge=dodge, gap=gap)

    def _set_group_offsets(
        self,
        group_names: Sequence,
        native_group_offsets: Sequence | None,
        width: float,
    ) -> tuple[Sequence, float]:
        """Set the group offsets from native scale and scale the width.

        :returns: ``(group_offsets, width)``. The offsets default to
            ``0, 1, 2, ...``; valid native offsets replace them and the width
            is multiplied by the smallest difference between consecutive
            native offsets.
        """
        group_offsets = list(range(len(group_names)))
        if native_group_offsets is None:
            return group_offsets, width

        curated_offsets = list(native_group_offsets)
        if len(curated_offsets) != len(group_names):
            msg = (
                'The values of the categories with "native_scale=True" do not correspond '
                'to the category names. Maybe some values are not finite?'
            )
            warnings.warn(msg)
            return group_offsets, width

        # NOTE(review): the width is scaled only when the native offsets are
        # accepted; the source's nesting here was not recoverable - confirm.
        if len(curated_offsets) > 1:
            width *= float(np.min(np.diff(curated_offsets)))
        return curated_offsets, width

    def _set_data(self, dodge: bool, gap: float) -> tuple[pd.DataFrame, float]:
        """Build the dataframe of groups, labels and positions.

        :returns: ``(data, artist_width)`` where ``artist_width`` is the width
            of a single artist (the group width divided among the hue levels
            when dodging).
        """
        n_repeat = max(len(self._hue_names), 1)
        group_positions = np.array(self.group_offsets)
        # One row per (group, hue) pair, hues varying fastest.
        positions = np.repeat(group_positions, n_repeat)
        artist_width = float(self.width)
        data = pd.DataFrame(
            {
                'group': self.tuple_group_names,
                'label': self.labels,
                'pos': positions,
            },
        )
        if dodge and self.use_hue:
            n_hues = max(len(self._hue_names), 1)
            artist_width /= n_hues
            # evenly space range centered in zero (subtracting the mean)
            offset = artist_width * (np.arange(n_hues) - (n_hues - 1) / 2)
            tiled_offset = np.tile(offset, len(self._group_names))
            data['pos'] = data['pos'] + tiled_offset
            # NOTE(review): gap is applied to dodged artists only; the
            # source's nesting here was not recoverable - confirm.
            if gap and 0 <= gap <= 1:
                artist_width *= 1 - gap

        return data, artist_width

    def find_group_at_pos(self, pos: float, *, verbose: bool = False) -> TupleGroup | None:
        """Return the group whose position is closest to ``pos``.

        :param pos: position on the categorical axis.
        :param verbose: warn when the closest position is farther than
            ``POSITION_TOLERANCE`` from ``pos``.
        :returns: the group tuple, or None when there are no groups.
        """
        positions = self._data['pos']
        if len(positions) == 0:
            return None
        # Get the index of the closest position
        index = (positions - pos).abs().idxmin()
        found_pos = positions.loc[index]

        if verbose and abs(found_pos - pos) > self.POSITION_TOLERANCE:
            # The requested position is not an artist position
            msg = (
                'Invalid x-position found. Are the same parameters passed to '
                'seaborn and statannotations calls? Or are there few data points? '
                f'The closest group position to {pos} is {found_pos}'
            )
            warnings.warn(msg)
        return self._data.loc[index, 'group']

    def get_group_axis_position(self, group: TupleGroup) -> float:
        """Get the position of the group.

        group_name can be either a tuple ("group",) or a tuple ("group", "hue")

        :returns: the position, rounded to a multiple of
            ``POSITION_TOLERANCE``.
        :raises ValueError: when ``group`` is not one of the known groups.
        """
        # Search the values: ``in`` on a pandas Series tests its index labels,
        # not its values, so it cannot be used for this membership check.
        group_names = self._data['group'].tolist()
        try:
            row = group_names.index(group)
        except ValueError:
            msg = f'Group {group} was not found in the list: {group_names}'
            raise ValueError(msg) from None
        pos = float(self._data['pos'].iloc[row])
        # round the position
        return round(pos / self.POSITION_TOLERANCE) * self.POSITION_TOLERANCE

    @property
    def artist_width(self) -> float:
        """Width of a single artist."""
        return float(self._artist_width)

    def compatible_width(self, width: float) -> bool:
        """Check if the rectangle width is smaller than the artist width.

        A 10% margin over the artist width is tolerated.
        """
        return abs(width) <= 1.1 * self._artist_width

    def iter_groups(self) -> Iterator[tuple[TupleGroup, str, float]]:
        """Iterate the groups and return a tuple (group_tuple, group_label, group_position)."""
        yield from self._data[['group', 'label', 'pos']].itertuples(index=False, name=None)
Loading

0 comments on commit 649e375

Please sign in to comment.