Skip to content

Commit 76c7f88

Browse files
authored
Merge pull request #50 from mmore500/main
Create context managers for temporary plot9/seaborn patching
2 parents 4428f24 + 010008a commit 76c7f88

File tree

3 files changed

+201
-25
lines changed

3 files changed

+201
-25
lines changed

API.md

+12
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@
2828

2929
`None`
3030

31+
- ### **`patched_axisgrid()`**
32+
33+
[Context manager](https://docs.python.org/3/reference/compound_stmts.html#with)/[decorator](https://docs.python.org/3/glossary.html#term-decorator)
34+
interface for `overwrite_axisgrid` patching that reverts changes when leaving
35+
`with`/function scope.
36+
37+
#### Returns
38+
39+
Context manager (i.e., `with patched_axisgrid():`) or decorator (i.e.,
40+
`@patched_axisgrid()`) that temporarily patches seaborn for patchworklib
41+
compatibility.
42+
3143
- ### **`load_seabornobj(g, label=None, labels=None, figsize=(3, 3))`**
3244

3345
Load a seaborn plot generated based on `seaborn._core.plot.Plotter` class.

patchworklib/patchworklib.py

+115-25
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
import matplotlib.axes as axes
1717
import plotnine
1818

19-
from contextlib import suppress
19+
from contextlib import contextmanager, suppress
2020
from types import SimpleNamespace as NS
21+
from unittest.mock import patch
2122
from matplotlib.offsetbox import AnchoredOffsetbox
2223
from matplotlib.transforms import Bbox, TransformedBbox, Affine2D
2324

@@ -210,11 +211,58 @@ def _reset_ggplot_legend(bricks):
210211
else:
211212
pass
212213

213-
def overwrite_plotnine():
214+
def _needs_plotnine_ggplot_draw_patch():
215+
"""Implementation detail for patched_plotnine, for internal use."""
214216
import plotnine
215-
plotnine.ggplot.draw = mp9.draw
217+
plotnine_version = plotnine.__version__
218+
219+
return StrictVersion(plotnine_version) >= StrictVersion("0.12")
220+
221+
@contextmanager
222+
def patched_plotnine():
223+
"""
224+
225+
Temporarily patch plot9 for patchworklib compatibility. Can be used as a
226+
context manager or a decorator.
227+
228+
Examples
229+
-------
230+
>>> with patched_plotnine():
231+
... pw.load_ggplot(
232+
... p9.ggplot(data, p9.aes(x="x", y="y", fill="fill")),
233+
... )
234+
235+
Example use as a context manager.
236+
237+
>>> @patched_plotnine()
238+
>>> def custom_plot():
239+
... pw.load_ggplot(
240+
... p9.ggplot(data, p9.aes(x="x", y="y", fill="fill")),
241+
... )
242+
>>> custom_plot()
243+
244+
Example use as a decorator.
245+
"""
246+
if _needs_plotnine_ggplot_draw_patch():
247+
with patch("plotnine.ggplot.ggplot.draw", mp9.draw):
248+
yield
249+
else:
250+
yield
251+
252+
def overwrite_plotnine():
253+
"""
254+
255+
Modify plot9 for patchworklib compatibility.
256+
257+
See Also
258+
--------
259+
patched_plotnine : Context manager that applies then reverses plotnine
260+
patches.
261+
"""
262+
patched_plotnine().__enter__()
216263

217-
def load_ggplot(ggplot=None, figsize=None):
264+
@patched_plotnine()
265+
def load_ggplot(ggplot=None, figsize=None):
218266
"""
219267
220268
Convert a plotnine plot object to a patchworklib.Bricks object.
@@ -231,6 +279,8 @@ def load_ggplot(ggplot=None, figsize=None):
231279
patchworklib.Bricks object.
232280
233281
"""
282+
import plotnine
283+
plotnine_version = plotnine.__version__
234284

235285
def draw_labels(bricks, gori, gcp, figsize):
236286
get_property = gcp.theme.themeables.property
@@ -453,10 +503,6 @@ def draw_title(bricks, gori, gcp, figsize):
453503
for ax in gori.axs:
454504
gori.theme.themeables['plot_title'].apply(ax)
455505

456-
import plotnine
457-
plotnine_version = plotnine.__version__
458-
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
459-
overwrite_plotnine()
460506

461507
#save_original_position
462508
global _basefigure
@@ -659,33 +705,77 @@ def draw_title(bricks, gori, gcp, figsize):
659705

660706
return return_obj
661707

708+
@contextmanager
709+
def patched_axisgrid():
710+
"""
711+
712+
Temporarily patch seaborn.axisgrid methods with patchworklib counterparts.
713+
This allows for custom behaviors in seaborn's grid objects like FacetGrid,
714+
PairGrid, JointGrid, and ClusterGrid, particularly useful when integrating
715+
seaborn's plotting with patchworklib's functionalities. Can be used as a
716+
context manager or a decorator.
717+
718+
Examples
719+
-------
720+
>>> with patched_axisgrid():
721+
... pw.load_seabornobj(
722+
... sns.jointplot(x="x", y="y", data=data),
723+
... )
724+
725+
Example use as a context manager.
726+
727+
>>> @patched_axisgrid()
728+
>>> def custom_plot():
729+
... pw.load_seabornobj(
730+
... sns.jointplot(x="x", y="y", data=data),
731+
... )
732+
>>> custom_plot()
733+
734+
Example use as a decorator.
735+
"""
736+
# patch("sns.pairplot", mg.pairplot)
737+
with patch.object(
738+
sns.axisgrid.Grid, "_figure", _basefigure, create=True
739+
), patch(
740+
"seaborn.axisgrid.Grid.add_legend", mg.add_legend
741+
), patch(
742+
"seaborn.axisgrid.FacetGrid.__init__", mg.__init_for_facetgrid__
743+
), patch(
744+
"seaborn.axisgrid.FacetGrid.despine", mg.despine
745+
), patch(
746+
"seaborn.axisgrid.PairGrid.__init__", mg.__init_for_pairgrid__
747+
), patch(
748+
"seaborn.axisgrid.JointGrid.__init__", mg.__init_for_jointgrid__
749+
), patch(
750+
"seaborn.matrix.ClusterGrid.__init__", mg.__init_for_clustergrid__
751+
), patch(
752+
"seaborn.matrix.ClusterGrid.__setattr__", mg.__setattr_for_clustergrid__
753+
), patch(
754+
"seaborn.matrix.ClusterGrid.plot", mg.__plot_for_clustergrid__
755+
):
756+
yield
757+
662758
def overwrite_axisgrid():
663759
"""
664760
665-
Overwrite `__init__` functions in seaborn.axisgrid.FacetGrid,
761+
Overwrite `__init__` functions in seaborn.axisgrid.FacetGrid,
666762
seaborn.axisgrid.PairGrid and seaborn.axisgrid.JointGrid.
667-
The function changes the figure object given in the `__init__` functions of the
668-
axisgrid class objects, which is used for drawing plots, to `_basefigure
669-
in the patchworklib. If you want to import plots generated baseon
670-
seabron.axisgrid.xxGrid objects as patchworklib.Brick(s) object by using
763+
The function changes the figure object given in the `__init__` functions of the
764+
axisgrid class objects, which is used for drawing plots, to `_basefigure
765+
in the patchworklib. If you want to import plots generated baseon
766+
seabron.axisgrid.xxGrid objects as patchworklib.Brick(s) object by using
671767
`load_seaborngrid` function, you should execute the function in advance.
672768
673769
Returns
674770
-------
675771
None.
676772
677-
"""
678-
679-
#sns.pairplot = mg.pairplot
680-
sns.axisgrid.Grid._figure = _basefigure
681-
sns.axisgrid.Grid.add_legend = mg.add_legend
682-
sns.axisgrid.FacetGrid.__init__ = mg.__init_for_facetgrid__
683-
sns.axisgrid.FacetGrid.despine = mg.despine
684-
sns.axisgrid.PairGrid.__init__ = mg.__init_for_pairgrid__
685-
sns.axisgrid.JointGrid.__init__ = mg.__init_for_jointgrid__
686-
sns.matrix.ClusterGrid.__init__ = mg.__init_for_clustergrid__
687-
sns.matrix.ClusterGrid.__setattr__ = mg.__setattr_for_clustergrid__
688-
sns.matrix.ClusterGrid.plot = mg.__plot_for_clustergrid__
773+
See Also
774+
--------
775+
patched_axisgrid : Context manager that applies then reverses axisgrid
776+
patches.
777+
"""
778+
patched_axisgrid().__enter__()
689779

690780
def load_seabornobj(g, label=None, labels=None, figsize=(3,3)):
691781
"""

tests/test_patchworklib.py

+74
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,77 @@ def test_sns_and_p9(tmp_path: Path):
4848
result_file = tmp_path / "g.png"
4949
g.savefig(result_file)
5050
assert result_file.exists()
51+
52+
53+
@pw.patched_axisgrid()
54+
def _make_seabornobj():
55+
iris = sns.load_dataset("iris")
56+
tips = sns.load_dataset("tips")
57+
58+
# An lmplot
59+
g0 = sns.lmplot(
60+
x="total_bill", y="tip", hue="smoker", data=tips, palette=dict(Yes="g", No="m")
61+
)
62+
g0 = pw.load_seaborngrid(g0, label="g0")
63+
64+
# A Pairplot
65+
g1 = sns.pairplot(iris, hue="species")
66+
g1 = pw.load_seaborngrid(g1, label="g1", figsize=(6, 6))
67+
68+
# A relplot
69+
g2 = sns.relplot(
70+
data=tips,
71+
x="total_bill",
72+
y="tip",
73+
col="time",
74+
hue="time",
75+
size="size",
76+
style="sex",
77+
palette=["b", "r"],
78+
sizes=(10, 100),
79+
)
80+
g2.set_titles("")
81+
g2 = pw.load_seaborngrid(g2, label="g2")
82+
83+
# A JointGrid
84+
g3 = sns.jointplot(
85+
data=iris, x="sepal_width", y="petal_length", kind="kde", space=0
86+
)
87+
g3 = pw.load_seaborngrid(g3, label="g3")
88+
89+
composite = (((g0/g3)["g0"]|g1)["g1"]/g2)
90+
return composite
91+
92+
93+
def test_load_seabornobj(tmp_path: Path):
94+
composite = _make_seabornobj()
95+
96+
result_file = tmp_path / "composite.png"
97+
composite.savefig(result_file)
98+
assert result_file.exists()
99+
100+
101+
@pw.patched_axisgrid() # duplicate patch wrapper
102+
def test_patch_nesting(tmp_path: Path):
103+
composite = _make_seabornobj()
104+
105+
result_file = tmp_path / "composite.png"
106+
composite.savefig(result_file)
107+
assert result_file.exists()
108+
109+
110+
def test_patched_axisgrid():
111+
with pw.patched_axisgrid():
112+
assert hasattr(sns.axisgrid.Grid, "_figure")
113+
assert sns.axisgrid.FacetGrid.add_legend is pw.modified_grid.add_legend
114+
115+
assert not hasattr(sns.axisgrid.Grid, "_figure")
116+
assert sns.axisgrid.FacetGrid.add_legend is not pw.modified_grid.add_legend
117+
118+
119+
def test_patched_plotnine():
120+
with pw.patched_plotnine():
121+
if pw.patchworklib._needs_plotnine_ggplot_draw_patch:
122+
assert p9.ggplot.draw is pw.modified_plotnine.draw
123+
124+
assert p9.ggplot.draw is not pw.modified_plotnine.draw

0 commit comments

Comments
 (0)