16
16
import matplotlib .axes as axes
17
17
import plotnine
18
18
19
- from contextlib import suppress
19
+ from contextlib import contextmanager , suppress
20
20
from types import SimpleNamespace as NS
21
+ from unittest .mock import patch
21
22
from matplotlib .offsetbox import AnchoredOffsetbox
22
23
from matplotlib .transforms import Bbox , TransformedBbox , Affine2D
23
24
@@ -210,11 +211,58 @@ def _reset_ggplot_legend(bricks):
210
211
else :
211
212
pass
212
213
213
- def overwrite_plotnine ():
214
+ def _needs_plotnine_ggplot_draw_patch ():
215
+ """Implementation detail for patched_plotnine, for internal use."""
214
216
import plotnine
215
- plotnine .ggplot .draw = mp9 .draw
217
+ plotnine_version = plotnine .__version__
218
+
219
+ return StrictVersion (plotnine_version ) >= StrictVersion ("0.12" )
220
+
221
+ @contextmanager
222
+ def patched_plotnine ():
223
+ """
224
+
225
+ Temporarily patch plot9 for patchworklib compatibility. Can be used as a
226
+ context manager or a decorator.
227
+
228
+ Examples
229
+ -------
230
+ >>> with patched_plotnine():
231
+ ... pw.load_ggplot(
232
+ ... p9.ggplot(data, p9.aes(x="x", y="y", fill="fill")),
233
+ ... )
234
+
235
+ Example use as a context manager.
236
+
237
+ >>> @patched_plotnine()
238
+ >>> def custom_plot():
239
+ ... pw.load_ggplot(
240
+ ... p9.ggplot(data, p9.aes(x="x", y="y", fill="fill")),
241
+ ... )
242
+ >>> custom_plot()
243
+
244
+ Example use as a decorator.
245
+ """
246
+ if _needs_plotnine_ggplot_draw_patch ():
247
+ with patch ("plotnine.ggplot.ggplot.draw" , mp9 .draw ):
248
+ yield
249
+ else :
250
+ yield
251
+
252
+ def overwrite_plotnine ():
253
+ """
254
+
255
+ Modify plot9 for patchworklib compatibility.
256
+
257
+ See Also
258
+ --------
259
+ patched_plotnine : Context manager that applies then reverses plotnine
260
+ patches.
261
+ """
262
+ patched_plotnine ().__enter__ ()
216
263
217
- def load_ggplot (ggplot = None , figsize = None ):
264
+ @patched_plotnine ()
265
+ def load_ggplot (ggplot = None , figsize = None ):
218
266
"""
219
267
220
268
Convert a plotnine plot object to a patchworklib.Bricks object.
@@ -231,6 +279,8 @@ def load_ggplot(ggplot=None, figsize=None):
231
279
patchworklib.Bricks object.
232
280
233
281
"""
282
+ import plotnine
283
+ plotnine_version = plotnine .__version__
234
284
235
285
def draw_labels (bricks , gori , gcp , figsize ):
236
286
get_property = gcp .theme .themeables .property
@@ -453,10 +503,6 @@ def draw_title(bricks, gori, gcp, figsize):
453
503
for ax in gori .axs :
454
504
gori .theme .themeables ['plot_title' ].apply (ax )
455
505
456
- import plotnine
457
- plotnine_version = plotnine .__version__
458
- if StrictVersion (plotnine_version ) >= StrictVersion ("0.12" ):
459
- overwrite_plotnine ()
460
506
461
507
#save_original_position
462
508
global _basefigure
@@ -659,33 +705,77 @@ def draw_title(bricks, gori, gcp, figsize):
659
705
660
706
return return_obj
661
707
708
+ @contextmanager
709
+ def patched_axisgrid ():
710
+ """
711
+
712
+ Temporarily patch seaborn.axisgrid methods with patchworklib counterparts.
713
+ This allows for custom behaviors in seaborn's grid objects like FacetGrid,
714
+ PairGrid, JointGrid, and ClusterGrid, particularly useful when integrating
715
+ seaborn's plotting with patchworklib's functionalities. Can be used as a
716
+ context manager or a decorator.
717
+
718
+ Examples
719
+ -------
720
+ >>> with patched_axisgrid():
721
+ ... pw.load_seabornobj(
722
+ ... sns.jointplot(x="x", y="y", data=data),
723
+ ... )
724
+
725
+ Example use as a context manager.
726
+
727
+ >>> @patched_axisgrid()
728
+ >>> def custom_plot():
729
+ ... pw.load_seabornobj(
730
+ ... sns.jointplot(x="x", y="y", data=data),
731
+ ... )
732
+ >>> custom_plot()
733
+
734
+ Example use as a decorator.
735
+ """
736
+ # patch("sns.pairplot", mg.pairplot)
737
+ with patch .object (
738
+ sns .axisgrid .Grid , "_figure" , _basefigure , create = True
739
+ ), patch (
740
+ "seaborn.axisgrid.Grid.add_legend" , mg .add_legend
741
+ ), patch (
742
+ "seaborn.axisgrid.FacetGrid.__init__" , mg .__init_for_facetgrid__
743
+ ), patch (
744
+ "seaborn.axisgrid.FacetGrid.despine" , mg .despine
745
+ ), patch (
746
+ "seaborn.axisgrid.PairGrid.__init__" , mg .__init_for_pairgrid__
747
+ ), patch (
748
+ "seaborn.axisgrid.JointGrid.__init__" , mg .__init_for_jointgrid__
749
+ ), patch (
750
+ "seaborn.matrix.ClusterGrid.__init__" , mg .__init_for_clustergrid__
751
+ ), patch (
752
+ "seaborn.matrix.ClusterGrid.__setattr__" , mg .__setattr_for_clustergrid__
753
+ ), patch (
754
+ "seaborn.matrix.ClusterGrid.plot" , mg .__plot_for_clustergrid__
755
+ ):
756
+ yield
757
+
662
758
def overwrite_axisgrid ():
663
759
"""
664
760
665
- Overwrite `__init__` functions in seaborn.axisgrid.FacetGrid,
761
+ Overwrite `__init__` functions in seaborn.axisgrid.FacetGrid,
666
762
seaborn.axisgrid.PairGrid and seaborn.axisgrid.JointGrid.
667
- The function changes the figure object given in the `__init__` functions of the
668
- axisgrid class objects, which is used for drawing plots, to `_basefigure
669
- in the patchworklib. If you want to import plots generated baseon
670
- seabron.axisgrid.xxGrid objects as patchworklib.Brick(s) object by using
763
+ The function changes the figure object given in the `__init__` functions of the
764
+ axisgrid class objects, which is used for drawing plots, to `_basefigure
765
+ in the patchworklib. If you want to import plots generated baseon
766
+ seabron.axisgrid.xxGrid objects as patchworklib.Brick(s) object by using
671
767
`load_seaborngrid` function, you should execute the function in advance.
672
768
673
769
Returns
674
770
-------
675
771
None.
676
772
677
- """
678
-
679
- #sns.pairplot = mg.pairplot
680
- sns .axisgrid .Grid ._figure = _basefigure
681
- sns .axisgrid .Grid .add_legend = mg .add_legend
682
- sns .axisgrid .FacetGrid .__init__ = mg .__init_for_facetgrid__
683
- sns .axisgrid .FacetGrid .despine = mg .despine
684
- sns .axisgrid .PairGrid .__init__ = mg .__init_for_pairgrid__
685
- sns .axisgrid .JointGrid .__init__ = mg .__init_for_jointgrid__
686
- sns .matrix .ClusterGrid .__init__ = mg .__init_for_clustergrid__
687
- sns .matrix .ClusterGrid .__setattr__ = mg .__setattr_for_clustergrid__
688
- sns .matrix .ClusterGrid .plot = mg .__plot_for_clustergrid__
773
+ See Also
774
+ --------
775
+ patched_axisgrid : Context manager that applies then reverses axisgrid
776
+ patches.
777
+ """
778
+ patched_axisgrid ().__enter__ ()
689
779
690
780
def load_seabornobj (g , label = None , labels = None , figsize = (3 ,3 )):
691
781
"""
0 commit comments