Skip to content

Commit

Permalink
Merge pull request #77 from Carifio24/scatter-artist-error-handling
Browse files Browse the repository at this point in the history
Add error handling to scatter layer artist
  • Loading branch information
Carifio24 authored Jul 5, 2024
2 parents 9996c84 + cb727cf commit 9065337
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions glue_plotly/viewers/scatter/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from glue_plotly.common import color_info
from glue_plotly.common.scatter2d import LINESTYLES, rectilinear_lines, scatter_mode, size_info
from glue.core import BaseData
from glue.core.exceptions import IncompatibleAttribute
from glue.utils import ensure_numerical
from glue.viewers.common.layer_artist import LayerArtist
from glue.viewers.scatter.state import ScatterLayerState
Expand Down Expand Up @@ -67,10 +68,6 @@ def __init__(self, view, viewer_state, layer_state=None, layer=None):
layer=layer
)

self._viewer_state.add_global_callback(self._update_display)
self.state.add_global_callback(self._update_display)
self.state.add_callback("zorder", self._update_zorder)

self.view = view

# Somewhat annoyingly, the trace that we pass in to be added
Expand All @@ -93,6 +90,10 @@ def __init__(self, view, viewer_state, layer_state=None, layer=None):
self._error_id = uuid4().hex
self._vector_id = uuid4().hex

self._viewer_state.add_global_callback(self._update_display)
self.state.add_global_callback(self._update_display)
self.state.add_callback("zorder", self._update_zorder)

def remove(self):
self.view._remove_traces([self._get_scatter()])
self.view._remove_traces(self._get_lines())
Expand All @@ -104,7 +105,14 @@ def _get_traces_with_id(self, id):
return self.view.figure.select_traces(dict(meta=id))

def _get_scatter(self):
return next(self._get_traces_with_id(self._scatter_id))
# The scatter trace should always exist
# so if somehow it doesn't, then create it
try:
return next(self._get_traces_with_id(self._scatter_id))
except StopIteration:
scatter = self._create_scatter()
self.view.figure.add_trace(scatter)
return scatter

def _get_lines(self):
return self._get_traces_with_id(self._lines_id)
Expand All @@ -120,8 +128,23 @@ def traces(self):

def _update_data(self):

x = ensure_numerical(self.layer[self._viewer_state.x_att].ravel())
y = ensure_numerical(self.layer[self._viewer_state.y_att].ravel())
try:
x = ensure_numerical(self.layer[self._viewer_state.x_att].ravel())
except (IncompatibleAttribute, IndexError):
if self._viewer_state.x_att is not None:
self.disable_invalid_attributes(self._viewer_state.x_att)
return
else:
self.enable()

try:
y = ensure_numerical(self.layer[self._viewer_state.y_att].ravel())
except (IncompatibleAttribute, IndexError):
if self._viewer_state.y_att is not None:
self.disable_invalid_attributes(self._viewer_state.y_att)
return
else:
self.enable()

scatter = self._get_scatter()
if self._viewer_state.using_rectilinear:
Expand Down

0 comments on commit 9065337

Please sign in to comment.