diff --git a/piel/visual/piel_fast.rcParams b/piel/visual/piel_fast.rcParams index bc828910..52ed4c15 100644 --- a/piel/visual/piel_fast.rcParams +++ b/piel/visual/piel_fast.rcParams @@ -404,7 +404,7 @@ axes.axisbelow: line # draw axis gridlines and ticks: #axes.unicode_minus: True # use Unicode for the minus symbol rather than hyphen. See # https://en.wikipedia.org/wiki/Plus_and_minus_signs#Character_codes -axes.prop_cycle: cycler('color', ['1982C4', '43A8E7', 'C6B7DA', '6B4C93', 'B79174', 'A17500', '8C564B', '32490E', '6A5541', '87C7F0']) +axes.prop_cycle: cycler('color', ['1982C4', '6B4C93', 'C6B7DA', 'B79174', 'A17500', '8C564B', '32490E', '6A5541', '87C7F0', '43A8E7']) # color cycle for plot lines as list of string color specs: # single letter, long name, or web-style hex # As opposed to all other parameters in this file, the color diff --git a/piel/visual/plot/basic.py b/piel/visual/plot/basic.py index 290f1717..b136bc92 100644 --- a/piel/visual/plot/basic.py +++ b/piel/visual/plot/basic.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd from typing import List, Optional, Any +from .position import create_axes_per_figure __all__ = [ "plot_simple", @@ -15,9 +16,12 @@ def plot_simple( ylabel: Optional[str] = None, xlabel: Optional[str] = None, fig: Optional[Any] = None, - ax: Optional[Any] = None, + axs: Optional[list[Any]] = None, title: Optional[str] = None, + plot_args: list = None, plot_kwargs: dict = None, + *args, + **kwargs, ) -> tuple: """ Plot a simple line graph. This function abstracts the basic files representation while @@ -30,7 +34,7 @@ def plot_simple( ylabel (Optional[str], optional): Y axis label. Defaults to None. xlabel (Optional[str], optional): X axis label. Defaults to None. fig (Optional[plt.Figure], optional): Matplotlib figure. Defaults to None. - ax (Optional[plt.Axes], optional): Matplotlib axes. Defaults to None. + axs (Optional[list[plt.Axes]], optional): Matplotlib axes. Defaults to None. title (Optional[str], optional): Title of the plot. Defaults to None. *args: Additional arguments passed to plt.plot(). **kwargs: Additional keyword arguments passed to plt.plot(). @@ -38,15 +42,18 @@ def plot_simple( Returns: Tuple[plt.Figure, plt.Axes]: The figure and axes of the plot. """ - import matplotlib.pyplot as plt - if fig is None and ax is None: - fig, ax = plt.subplots() + if fig is None and axs is None: + fig, axs = create_axes_per_figure(rows=1, columns=1) if plot_kwargs is None: plot_kwargs = dict() - ax.plot(x_data, y_data, label=label, **plot_kwargs) + if plot_args is None: + plot_args = list() + + ax = axs[0] + ax.plot(x_data, y_data, label=label, *plot_args, **plot_kwargs) if ylabel is not None: ax.set_ylabel(ylabel) @@ -67,7 +74,7 @@ def plot_simple( fig.tight_layout() - return fig, ax + return fig, axs def plot_simple_multi_row( diff --git a/pyproject.toml b/pyproject.toml index 82ef9162..9274b1a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ exclude = "docs" [tools.pytest.ini_options] addopts = "-rav" minversion = "6.0" -testpaths = ["tests"] +testpaths = ["tests/"] [tool.ruff.lint.per-file-ignores]