Skip to content

Commit

Permalink
added axes share in canvas
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Jan 5, 2024
1 parent 36fff85 commit a789b0f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 30 deletions.
28 changes: 20 additions & 8 deletions atom/plots/basefigure.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,19 @@ class BaseFigure:
cols: int, default=1
Number of subplot columns in the canvas.
horizontal_spacing: float, default=0.05
sharex: bool, default=False
If True, hide the label and ticks from non-border subplots
on the x-axis.
sharey: bool, default=False
If True, hide the label and ticks from non-border subplots
on the y-axis.
hspace: float, default=0.05
Space between subplot rows in normalized plot coordinates.
The spacing is relative to the figure's size.
vertical_spacing: float, default=0.07
vspace: float, default=0.07
Space between subplot cols in normalized plot coordinates.
The spacing is relative to the figure's size.
Expand All @@ -66,17 +74,21 @@ def __init__(
rows: IntLargerZero = 1,
cols: IntLargerZero = 1,
*,
horizontal_spacing: FloatZeroToOneExc = 0.05,
vertical_spacing: FloatZeroToOneExc = 0.07,
sharex: Bool = False,
sharey: Bool = False,
hspace: FloatZeroToOneExc = 0.05,
vspace: FloatZeroToOneExc = 0.07,
palette: str | Sequence[str] = "Prism",
is_canvas: Bool = False,
backend: PlotBackend = "plotly",
create_figure: Bool = True,
):
self.rows = rows
self.cols = cols
self.horizontal_spacing = horizontal_spacing
self.vertical_spacing = vertical_spacing
self.sharex = sharex
self.sharey = sharey
self.hspace = hspace
self.vspace = vspace
if isinstance(palette, str):
self._palette = getattr(px.colors.qualitative, palette)
self.palette = cycle(self._palette)
Expand Down Expand Up @@ -235,8 +247,8 @@ def get_axes(
self.axes += 1

# Calculate the distance between subplots
x_offset = divide(self.horizontal_spacing, (self.cols - 1))
y_offset = divide(self.vertical_spacing, (self.rows - 1))
x_offset = divide(self.hspace, (self.cols - 1))
y_offset = divide(self.vspace, (self.rows - 1))

# Calculate the size of the subplot
x_size = (1 - ((x_offset * 2) * (self.cols - 1))) / self.cols
Expand Down
76 changes: 54 additions & 22 deletions atom/plots/baseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,16 +531,33 @@ def _plot(
fig = fig or BasePlot._fig.figure
if isinstance(fig, go.Figure):
if isinstance(ax, tuple):
# Hide the axis' label and ticks from non-border subplots
if not BasePlot._fig.sharex or self._fig.grid[0] == self._fig.rows:
fig.update_layout(
{
f"{ax[0]}_title": {
"text": kwargs.get("xlabel"),
"font_size": self.label_fontsize,
}
}
)
else:
fig.update_layout({f"{ax[0]}_showticklabels": False})

if not BasePlot._fig.sharey or self._fig.grid[1] == 1:
fig.update_layout(
{
f"{ax[1]}_title": {
"text": kwargs.get("ylabel"),
"font_size": self.label_fontsize,
}
}
)
else:
fig.update_layout({f"{ax[1]}_showticklabels": False})

fig.update_layout(
{
f"{ax[0]}_title": {
"text": kwargs.get("xlabel"),
"font_size": self.label_fontsize,
},
f"{ax[1]}_title": {
"text": kwargs.get("ylabel"),
"font_size": self.label_fontsize,
},
f"{ax[0]}_range": kwargs.get("xlim"),
f"{ax[1]}_range": kwargs.get("ylim"),
f"{ax[0]}_automargin": True,
Expand Down Expand Up @@ -692,8 +709,10 @@ def canvas(
rows: IntLargerZero = 1,
cols: IntLargerZero = 2,
*,
horizontal_spacing: FloatZeroToOneExc = 0.05,
vertical_spacing: FloatZeroToOneExc = 0.07,
sharex: Bool = False,
sharey: Bool = False,
hspace: FloatZeroToOneExc = 0.05,
vspace: FloatZeroToOneExc = 0.07,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "out",
figsize: tuple[IntLargerZero, IntLargerZero] | None = None,
Expand All @@ -714,11 +733,19 @@ def canvas(
cols: int, default=2
Number of plots in width.
horizontal_spacing: float, default=0.05
sharex: bool, default=False
If True, hide the label and ticks from non-border subplots
on the x-axis.
sharey: bool, default=False
If True, hide the label and ticks from non-border subplots
on the y-axis.
hspace: float, default=0.05
Space between subplot rows in normalized plot coordinates.
The spacing is relative to the figure's size.
vertical_spacing: float, default=0.07
vspace: float, default=0.07
Space between subplot cols in normalized plot coordinates.
The spacing is relative to the figure's size.
Expand Down Expand Up @@ -759,8 +786,10 @@ def canvas(
BasePlot._fig = BaseFigure(
rows=rows,
cols=cols,
horizontal_spacing=horizontal_spacing,
vertical_spacing=vertical_spacing,
sharex=sharex,
sharey=sharey,
hspace=hspace,
vspace=vspace,
palette=self.palette,
is_canvas=True,
)
Expand All @@ -779,11 +808,12 @@ def canvas(
display=display,
)

def reset_aesthetics(self):
@classmethod
def reset_aesthetics(cls):
"""Reset the plot [aesthetics][] to their default values."""
self._custom_layout = {}
self._custom_traces = {}
self._aesthetics = Aesthetics(
cls._custom_layout = {}
cls._custom_traces = {}
cls._aesthetics = Aesthetics(
palette=list(PALETTE),
title_fontsize=24,
label_fontsize=16,
Expand All @@ -792,7 +822,8 @@ def reset_aesthetics(self):
marker_size=8,
)

def update_layout(self, **kwargs):
@classmethod
def update_layout(cls, **kwargs):
"""Update the properties of the plot's layout.
Recursively update the structure of the original layout with
Expand All @@ -804,9 +835,10 @@ def update_layout(self, **kwargs):
Keyword arguments for the figure's [update_layout][] method.
"""
self._custom_layout = kwargs
cls._custom_layout = kwargs

def update_traces(self, **kwargs):
@classmethod
def update_traces(cls, **kwargs):
"""Update the properties of the plot's traces.
Recursively update the structure of the original traces with
Expand All @@ -818,4 +850,4 @@ def update_traces(self, **kwargs):
Keyword arguments for the figure's [update_traces][] method.
"""
self._custom_traces = kwargs
cls._custom_traces = kwargs
1 change: 1 addition & 0 deletions docs_sources/changelog/v5.x.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
* Transformations only on `y` are now accepted, e.g., `atom.scale(columns=-1)`.
* Full support for [pandas nullable dtypes](https://pandas.pydata.org/docs/user_guide/integer_na.html).
* The dataset can now be provided as callable.
* Subplots can now share axes on the [canvas][atomclassifier-canvas].
* The [save][atomclassifier-save] and [save_data][atomclassifier-save_data]
methods now accept [pathlib.Path][] objects as `filename`.
* Cleaner representation on hover for the [plot_timeline][] method.
Expand Down

0 comments on commit a789b0f

Please sign in to comment.