Skip to content

Commit

Permalink
Merge pull request #92 from Carifio24/ipyvolume
Browse files Browse the repository at this point in the history
Add support for ipyvolume viewers
  • Loading branch information
Carifio24 authored Oct 3, 2024
2 parents 15be720 + e120dd3 commit 478712c
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 49 deletions.
15 changes: 10 additions & 5 deletions glue_plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def setup_jupyter():
from glue_jupyter.bqplot.image import BqplotImageView
from glue_jupyter.bqplot.profile import BqplotProfileView
from glue_jupyter.bqplot.scatter import BqplotScatterView
from glue_jupyter.ipyvolume import IpyvolumeScatterView, IpyvolumeVolumeView

from glue_jupyter.ipyvolume.common.viewer import IpyvolumeBaseView
print(IpyvolumeBaseView.tools)
BqplotHistogramView.tools += ['save:bqplot_plotlyhist']
BqplotImageView.tools += ['save:bqplot_plotlyimage2d']
BqplotProfileView.tools += ['save:bqplot_plotlyprofile']
BqplotScatterView.tools += ['save:bqplot_plotly2d']
IpyvolumeScatterView.tools = [tool for tool in IpyvolumeScatterView.tools] + ['save:jupyter_plotly3dscatter']
IpyvolumeVolumeView.tools = [tool for tool in IpyvolumeVolumeView.tools] + ['save:jupyter_plotlyvolume']

try:
from glue_vispy_viewers.scatter.jupyter import JupyterVispyScatterViewer
Expand All @@ -99,8 +109,3 @@ def setup_jupyter():
else:
JupyterVispyScatterViewer.tools += ['save:jupyter_plotly3dscatter']
JupyterVispyVolumeViewer.tools += ['save:jupyter_plotlyvolume']

BqplotHistogramView.tools += ['save:bqplot_plotlyhist']
BqplotImageView.tools += ['save:bqplot_plotlyimage2d']
BqplotProfileView.tools += ['save:bqplot_plotlyprofile']
BqplotScatterView.tools += ['save:bqplot_plotly2d']
42 changes: 38 additions & 4 deletions glue_plotly/common/base_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,38 @@ def dimensions(viewer_state):


def projection_type(viewer_state):
return "perspective" if viewer_state.perspective_view else "orthographic"
return "perspective" if getattr(viewer_state, "perspective_view", True) else "orthographic"


def get_resolution(viewer_state):
try:
from glue_vispy_viewers.volume.viewer_state import Vispy3DVolumeViewerState
if isinstance(viewer_state, Vispy3DVolumeViewerState):
return viewer_state.resolution
except ImportError:
pass

try:
from glue_jupyter.common.state3d import VolumeViewerState
if isinstance(viewer_state, VolumeViewerState):
resolutions = tuple(getattr(state, 'max_resolution', None) for state in viewer_state.layers)
return max((res for res in resolutions if res is not None), default=256)
except ImportError:
pass

return 256


# TODO: Update other methods to not rely on these being reversed
def bounds(viewer_state, with_resolution=False):
bds = [(viewer_state.z_min, viewer_state.z_max),
(viewer_state.y_min, viewer_state.y_max),
(viewer_state.x_min, viewer_state.x_max)]
if with_resolution:
resolution = get_resolution(viewer_state)
return [(*b, resolution) for b in bds]

return bds


def axis(viewer_state, ax):
Expand Down Expand Up @@ -84,6 +115,9 @@ def plotly_up_from_vispy(vispy_up):

def layout_config(viewer_state):
width, height, depth = dimensions(viewer_state)
x_stretch = getattr(viewer_state, "x_stretch", 1.)
y_stretch = getattr(viewer_state, "y_stretch", 1.)
z_stretch = getattr(viewer_state, "z_stretch", 1.)
return dict(
margin=dict(r=50, l=50, b=50, t=50), # noqa
width=1200,
Expand All @@ -99,9 +133,9 @@ def layout_config(viewer_state):
# Currently there's no way to change this in glue
up=plotly_up_from_vispy("+z")
),
aspectratio=dict(x=1 * viewer_state.x_stretch,
y=height / width * viewer_state.y_stretch,
z=depth / width * viewer_state.z_stretch),
aspectratio=dict(x=1 * x_stretch,
y=height / width * y_stretch,
z=depth / width * z_stretch),
aspectmode='manual'
)
)
19 changes: 19 additions & 0 deletions glue_plotly/common/scatter3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,21 @@ def error_bar_info(layer_state, mask):
return errs


_IPYVOLUME_GEOMETRY_SYMBOLS = {
"sphere": "circle",
"box": "square",
"diamond": "diamond",
"circle2d": "circle",
}


def symbol_for_geometry(geometry: str) -> str:
symbol = _IPYVOLUME_GEOMETRY_SYMBOLS.get(geometry)
if symbol is not None:
return symbol
raise ValueError(f"Invalid geometry: {geometry}")


def traces_for_layer(viewer_state, layer_state, hover_data=None, add_data_label=True):

x, y, z, mask = clipped_data(viewer_state, layer_state)
Expand All @@ -103,6 +118,10 @@ def traces_for_layer(viewer_state, layer_state, hover_data=None, add_data_label=
opacity=layer_state.alpha,
line=dict(width=0))

if hasattr(layer_state, "geo"):
symbol = symbol_for_geometry(layer_state.geo)
marker["symbol"] = symbol

if hover_data is None or np.sum(hover_data) == 0:
hoverinfo = 'skip'
hovertext = None
Expand Down
20 changes: 3 additions & 17 deletions glue_plotly/common/volume.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from glue_plotly.utils import rgba_components
from glue_plotly.utils import frb_for_layer, rgba_components
from numpy import linspace, meshgrid, nan_to_num, nanmin

from glue.core import BaseData
Expand Down Expand Up @@ -30,23 +30,9 @@ def values(viewer_state, layer_state, bounds, precomputed=None):
parent = layer_state.layer.data if subset_layer else layer_state.layer
parent_label = parent.label
if precomputed is not None and parent_label in precomputed:
data = precomputed[parent_label]
values = precomputed[parent_label]
else:
data = parent.compute_fixed_resolution_buffer(
target_data=viewer_state.reference_data,
bounds=bounds,
target_cid=layer_state.attribute
)

if subset_layer:
subcube = parent.compute_fixed_resolution_buffer(
target_data=viewer_state.reference_data,
bounds=bounds,
subset_state=layer_state.layer.subset_state
)
values = subcube * data
else:
values = data
values = frb_for_layer(viewer_state, layer_state, bounds)

# This accounts for two transformations: the fact that the viewer bounds are in reverse order,
# plus a need to change R -> L handedness for Plotly
Expand Down
4 changes: 2 additions & 2 deletions glue_plotly/html_exporters/jupyter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import image # noqa
from . import profile # noqa
from . import scatter2d # noqa
from . import vispy_scatter # noqa
from . import vispy_volume # noqa
from . import scatter3d # noqa
from . import volume # noqa
2 changes: 1 addition & 1 deletion glue_plotly/html_exporters/jupyter/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from glue_jupyter import jglue


class TestBqplotExporter:
class BaseTestJupyterExporter:

viewer_type = None
tool_id = None
Expand Down
6 changes: 3 additions & 3 deletions glue_plotly/html_exporters/jupyter/tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

importorskip('glue_jupyter')

from glue_jupyter.bqplot.histogram import BqplotHistogramView # noqa
from glue_jupyter.bqplot.histogram import BqplotHistogramView # noqa: E402

from .test_base import TestBqplotExporter # noqa
from .test_base import BaseTestJupyterExporter # noqa: E402


class TestHistogram(TestBqplotExporter):
class TestHistogram(BaseTestJupyterExporter):

viewer_type = BqplotHistogramView
tool_id = 'save:bqplot_plotlyhist'
Expand Down
8 changes: 4 additions & 4 deletions glue_plotly/html_exporters/jupyter/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

importorskip('glue_jupyter')

from glue_jupyter.bqplot.image import BqplotImageView # noqa
from glue_jupyter.bqplot.image import BqplotImageView # noqa: E402

from numpy import arange, ones # noqa
from numpy import arange, ones # noqa: E402

from .test_base import TestBqplotExporter # noqa
from .test_base import BaseTestJupyterExporter # noqa: E402


class TestImage(TestBqplotExporter):
class TestImage(BaseTestJupyterExporter):

viewer_type = BqplotImageView
tool_id = 'save:bqplot_plotlyimage2d'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os

from glue.core import Data
from glue_plotly.html_exporters.jupyter.tests.test_base import BaseTestJupyterExporter

from pytest import importorskip

importorskip('glue_jupyter')

from glue_jupyter.ipyvolume import IpyvolumeScatterView # noqa: E402


class TestScatter3D(BaseTestJupyterExporter):

viewer_type = IpyvolumeScatterView
tool_id = 'save:jupyter_plotly3dscatter'

def make_data(self):
return Data(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9], label='d1')

def test_default(self, tmpdir):
output_path = self.export_figure(tmpdir, 'test_default.html')
assert os.path.exists(output_path)
28 changes: 28 additions & 0 deletions glue_plotly/html_exporters/jupyter/tests/test_ipyvolume_volume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os

from glue.core import Data
from glue_plotly.html_exporters.jupyter.tests.test_base import BaseTestJupyterExporter

from pytest import importorskip

importorskip('glue_jupyter')

from glue_jupyter.ipyvolume import IpyvolumeVolumeView # noqa: E402

from numpy import arange, ones # noqa: E402


class TestVolume(BaseTestJupyterExporter):

viewer_type = IpyvolumeVolumeView
tool_id = 'save:jupyter_plotlyvolume'

def make_data(self):
return Data(label='d1',
x=arange(24).reshape((2, 3, 4)),
y=ones((2, 3, 4)),
z=arange(100, 124).reshape((2, 3, 4)))

def test_default(self, tmpdir):
output_path = self.export_figure(tmpdir, 'test_default.html')
assert os.path.exists(output_path)
6 changes: 3 additions & 3 deletions glue_plotly/html_exporters/jupyter/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

importorskip('glue_jupyter')

from glue_jupyter.bqplot.profile import BqplotProfileView # noqa
from glue_jupyter.bqplot.profile import BqplotProfileView # noqa: E402

from .test_base import TestBqplotExporter # noqa
from .test_base import BaseTestJupyterExporter # noqa: E402


class TestProfile(TestBqplotExporter):
class TestProfile(BaseTestJupyterExporter):

viewer_type = BqplotProfileView
tool_id = 'save:bqplot_plotlyprofile'
Expand Down
6 changes: 3 additions & 3 deletions glue_plotly/html_exporters/jupyter/tests/test_scatter2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

importorskip('glue_jupyter')

from glue_jupyter.bqplot.scatter import BqplotScatterView # noqa
from glue_jupyter.bqplot.scatter import BqplotScatterView # noqa: E402

from .test_base import TestBqplotExporter # noqa
from .test_base import BaseTestJupyterExporter # noqa: E402


class TestScatter2D(TestBqplotExporter):
class TestScatter2D(BaseTestJupyterExporter):

viewer_type = BqplotScatterView
tool_id = 'save:bqplot_plotly2d'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from glue.core import Data
from glue_plotly.html_exporters.jupyter.tests.test_base import TestBqplotExporter
from glue_plotly.html_exporters.jupyter.tests.test_base import BaseTestJupyterExporter

from pytest import importorskip

Expand All @@ -11,7 +11,7 @@
from glue_vispy_viewers.scatter.jupyter import JupyterVispyScatterViewer # noqa: E402


class TestScatter3D(TestBqplotExporter):
class TestScatter3D(BaseTestJupyterExporter):

viewer_type = JupyterVispyScatterViewer
tool_id = 'save:jupyter_plotly3dscatter'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from glue.core import Data
from glue_plotly.html_exporters.jupyter.tests.test_base import TestBqplotExporter
from glue_plotly.html_exporters.jupyter.tests.test_base import BaseTestJupyterExporter

from pytest import importorskip

Expand All @@ -13,7 +13,7 @@
from numpy import arange, ones # noqa: E402


class TestVolume(TestBqplotExporter):
class TestVolume(BaseTestJupyterExporter):

viewer_type = JupyterVispyVolumeViewer
tool_id = 'save:jupyter_plotlyvolume'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from glue.config import viewer_tool
from glue_vispy_viewers.scatter.layer_artist import ScatterLayerArtist

from glue_plotly.common.base_3d import layout_config
from glue_plotly.common.base_3d import bounds, layout_config
from glue_plotly.common.common import data_count, layers_to_export
from glue_plotly.common.scatter3d import traces_for_layer as scatter3d_traces_for_layer
from glue_plotly.common.volume import traces_for_layer as volume_traces_for_layer
Expand All @@ -26,14 +26,14 @@ def save_figure(self, filepath):

layers = layers_to_export(self.viewer)
add_data_label = data_count(layers) > 1
bounds = self.viewer._vispy_widget._multivol._data_bounds
bds = bounds(self.viewer.state, with_resolution=True)
count = 5
for layer in layers:
if isinstance(layer, ScatterLayerArtist):
traces = scatter3d_traces_for_layer(self.viewer.state, layer.state,
add_data_label=add_data_label)
else:
traces = volume_traces_for_layer(self.viewer.state, layer.state, bounds,
traces = volume_traces_for_layer(self.viewer.state, layer.state, bds,
isosurface_count=count,
add_data_label=add_data_label)

Expand Down
35 changes: 35 additions & 0 deletions glue_plotly/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from re import match, sub

from glue.core import BaseData
from glue.viewers.common.state import LayerState

__all__ = [
'cleaned_labels',
'mpl_ticks_values',
Expand Down Expand Up @@ -93,3 +96,35 @@ def rgba_components(color):
def components_to_hex(r, g, b, a=None):
components = [hex_string(t) for t in (r, g, b, a) if t is not None]
return f"#{''.join(components)}"


def data_for_layer(layer_or_state):
if isinstance(layer_or_state.layer, BaseData):
return layer_or_state.layer
else:
return layer_or_state.layer.data


def frb_for_layer(viewer_state,
layer_or_state,
bounds):

data = data_for_layer(layer_or_state)
layer_state = layer_or_state if isinstance(layer_or_state, LayerState) else layer_or_state.state
is_data_layer = data is layer_or_state.layer
target_data = getattr(viewer_state, 'reference_data', data)
data_frb = data.compute_fixed_resolution_buffer(
target_data=target_data,
bounds=bounds,
target_cid=layer_state.attribute
)

if is_data_layer:
return data_frb
else:
subcube = data.compute_fixed_resolution_buffer(
target_data=target_data,
bounds=bounds,
subset_state=layer_state.layer.subset_state
)
return subcube * data_frb

0 comments on commit 478712c

Please sign in to comment.