diff --git a/scikit_posthocs/_plotting.py b/scikit_posthocs/_plotting.py index 20be9ca..1d6efa1 100644 --- a/scikit_posthocs/_plotting.py +++ b/scikit_posthocs/_plotting.py @@ -364,6 +364,7 @@ def critical_difference_diagram( ranks: Union[dict, Series], sig_matrix: DataFrame, *, + cd: Optional[float] = None, ax: Optional[Axes] = None, label_fmt_left: str = "{label} ({rank:.2g})", label_fmt_right: str = "({rank:.2g}) {label}", @@ -412,6 +413,9 @@ def critical_difference_diagram( The object in which the plot will be built. Gets the current Axes by default (if None is passed). + cd : float, optional + Critical difference value to be plotted as a horizontal line above the axis. + label_fmt_left : str, optional The format string to apply to the labels on the left side. The keywords label and rank can be used to specify the sample/estimator name and @@ -588,6 +592,22 @@ def plot_items(points, xpos, label_fmt, color_palette, label_props): ) ypos -= 1 + def plot_cd(cd, label_props, crossbar_props): + """Plot CD horizontal line and label above OX axis if CD is passed.""" + if cd is not None: + x_min, x_max = ax.get_xlim() + ax.axhline(0.5, 0, (cd / (x_max - x_min)), **crossbar_props) + ax.text((x_min + cd / 2), 0.65, f"CD = {cd:.2g}", **label_props) + + plot_cd( + cd, + crossbar_props=crossbar_props, + label_props={ + "ha": "center", + **label_props, + } + ) + plot_items( points_left, xpos=points_left.iloc[0] - text_h_margin,