Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions src/arviz_plots/style.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
"""Style/templating helpers."""
import os
from pathlib import Path

from arviz_base import rcParams


def use(name):
"""Set an arviz style as the default style/template for all available backends.

The style will be set for all backends that support it and have it available.
The supported backends are Matplotlib, Plotly and Bokeh.
You can use ``arviz_plots.style.available()`` to check which styles are available.
The ones that work for all backends are listed under the 'common' key.

Parameters
----------
name : str
Name of the style to be set as default.

Raises
------
ValueError
If the style with the given name is not found.

"""
ok = False

Expand All @@ -33,38 +44,64 @@ def use(name):
pass

try:
if name in ["arviz-cetrino", "arviz-tenui", "arviz-tumma", "arviz-variat", "arviz-vibrant"]:
from bokeh.io import curdoc
from bokeh.themes import Theme
from bokeh.io import curdoc
from bokeh.themes import Theme

path = os.path.dirname(os.path.abspath(__file__))
curdoc().theme = Theme(filename=f"{path}/styles/{name}.yml")
path = Path(__file__).parent / "styles" / f"{name}.yml"
if path.exists():
curdoc().theme = Theme(filename=str(path))
ok = True
except (ImportError, FileNotFoundError):
except ImportError:
pass

if not ok:
raise ValueError(f"Style {name} not found.")


def available():
"""List available styles."""
"""List available styles.

If multiple backends are installed, it also lists styles common to
all backends under the 'common' key.

Returns
-------
dict
Dictionary with backend names as keys and list of available styles as values.
"""
styles = {}

n_backends = 0
try:
import matplotlib.pyplot as plt

styles["matplotlib"] = plt.style.available
n_backends += 1
except ImportError:
pass

try:
import plotly.io as pio

styles["plotly"] = list(pio.templates)
n_backends += 1
except ImportError:
pass

try:
from bokeh.themes import built_in_themes

path = Path(__file__).parent / "styles"
custom = [file.stem for file in path.glob("*.yml") if path.exists()]
styles["bokeh"] = list(built_in_themes) + custom
n_backends += 1
except ImportError:
pass

if n_backends > 1:
common = set.intersection(*(set(v) for v in styles.values()))
styles["common"] = list(common)

return styles


Expand All @@ -81,7 +118,7 @@ def get(name, backend=None):
"""
if backend is None:
backend = rcParams["plot.backend"]
if backend not in ["matplotlib", "plotly"]:
if backend not in ["matplotlib", "plotly", "bokeh"]:
raise ValueError(f"Default styles/templates are not supported for Backend {backend}")

if backend == "matplotlib":
Expand All @@ -96,4 +133,11 @@ def get(name, backend=None):
if name in pio.templates:
return pio.templates[name]

elif backend == "bokeh":
from bokeh.themes import Theme

path = Path(__file__).parent / "styles" / f"{name}.yml"
if path.exists():
return Theme(filename=path)

raise ValueError(f"Style {name} not found.")