Skip to content

Commit

Permalink
Merge pull request #16 from Carifio24/errors
Browse files Browse the repository at this point in the history
Initial support for error bars
  • Loading branch information
Carifio24 authored Dec 29, 2023
2 parents 810f54a + 27005d4 commit f99999d
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 48 deletions.
151 changes: 108 additions & 43 deletions glue_ar/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,85 @@ def scatter_layer_as_glyphs(viewer_state, layer_state, glyph):
}


def vector_meshes_for_layer(viewer_state, layer_state,
data, bounds,
tip_resolution=10,
shaft_resolution=10,
mask=None):
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
tip_factor = 0.25 if layer_state.vector_arrowhead else 0
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]
if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) for b in bounds))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) for b in bounds]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

arrows = []
offset = _VECTOR_OFFSETS[layer_state.vector_origin]
for pt, v in zip(data, vector_data):
adjusted_v = v * layer_state.vector_scaling
length = norm(adjusted_v)
tip_length = tip_factor * length
adjusted_pt = [c + offset * vc for c, vc in zip(pt, adjusted_v)]
arrow = pv.Arrow(
start=adjusted_pt,
direction=v,
shaft_resolution=shaft_resolution,
tip_resolution=tip_resolution,
shaft_radius=0.002,
tip_radius=0.01,
scale=length,
tip_length=tip_length)
arrows.append(arrow)

return arrows


def meshes_for_error_bars(viewer_state, layer_state, axis, data, bounds, mask=None):
att = getattr(layer_state, f"{axis}err_attribute")
err_values = layer_state.layer[att].ravel()[mask]
index = ['x', 'y', 'z'].index(axis)
axis_range = abs(bounds[index][1] - bounds[index][0])
if viewer_state.native_aspect:
max_range = max((abs(b[1] - b[0]) for b in bounds))
factor = 1 / max_range
else:
factor = 1 / axis_range
err_values *= factor
lines = []
for pt, err in zip(data, err_values):
start = [c - err if idx == index else c for idx, c in enumerate(pt)]
end = [c + err if idx == index else c for idx, c in enumerate(pt)]
lines.append(pv.Line(start, end))

return lines


# This function creates a multiblock mesh for a given scatter layer
# Everything is scaled into clip space for better usability with e.g. model-viewer
def scatter_layer_as_multiblock(viewer_state, layer_state,
theta_resolution=8,
phi_resolution=8,
clip_to_bounds=True,
scaled=True):
meshes = []
bounds = xyz_bounds(viewer_state)
if clip_to_bounds:
mask = mask_for_bounds(viewer_state, layer_state, bounds)
else:
mask = None

fixed_color = layer_state.color_mode == "Fixed"
data = xyz_for_layer(viewer_state, layer_state,
preserve_aspect=viewer_state.native_aspect,
mask=mask,
scaled=scaled)
factor = max((abs(b[1] - b[0]) for b in bounds))
if layer_state.size_mode == "Fixed":
radius = layer_state.size_scaling * sqrt((layer_state.size)) / (7 * factor)
radius = layer_state.size_scaling * sqrt((layer_state.size)) / (10 * factor)
spheres = [pv.Sphere(center=p, radius=radius,
phi_resolution=phi_resolution,
theta_resolution=theta_resolution) for p in data]
Expand All @@ -78,59 +138,36 @@ def scatter_layer_as_multiblock(viewer_state, layer_state,
spheres = [pv.Sphere(center=p, radius=r,
phi_resolution=phi_resolution,
theta_resolution=theta_resolution) for p, r in zip(data, sizes)]

if not fixed_color:
points_per_sphere = 2 + (phi_resolution - 2) * theta_resolution
cmap_values = layer_state.layer[layer_state.cmap_attribute][mask]
point_cmap_values = [y for x in cmap_values for y in (x,) * points_per_sphere]

blocks = pv.MultiBlock(spheres)

# Create the meshes for vectors, if necessary
if layer_state.vector_visible:
tip_resolution = 10
shaft_resolution = 10
atts = [layer_state.vx_attribute, layer_state.vy_attribute, layer_state.vz_attribute]
tip_factor = 0.25 if layer_state.vector_arrowhead else 0
vector_data = [layer_state.layer[att].ravel()[mask] for att in atts]
if viewer_state.native_aspect:
factor = max((abs(b[1] - b[0]) for b in bounds))
vector_data = [[0.5 * t / factor for t in v] for v in vector_data]
else:
bound_factors = [abs(b[1] - b[0]) for b in bounds]
vector_data = [[0.5 * t / b for t in v] for v, b in zip(vector_data, bound_factors)]
vector_data = array(list(zip(*vector_data)))

arrows = []
offset = _VECTOR_OFFSETS[layer_state.vector_origin]
for pt, v in zip(data, vector_data):
adjusted_v = v * layer_state.vector_scaling
length = norm(adjusted_v)
tip_length = tip_factor * length
adjusted_pt = [c + offset * vc for c, vc in zip(pt, adjusted_v)]
arrow = pv.Arrow(
start=adjusted_pt,
direction=v,
shaft_resolution=shaft_resolution,
tip_resolution=tip_resolution,
shaft_radius=0.002,
tip_radius=0.01,
scale=length,
tip_length=tip_length)
arrows.append(arrow)

tip_resolution = 10
arrows = vector_meshes_for_layer(viewer_state, layer_state,
data, bounds,
tip_resolution=tip_resolution,
shaft_resolution=shaft_resolution,
mask=mask)
points_per_arrow = (4 * shaft_resolution) + tip_resolution + 1
point_cmap_values.extend([y for x in cmap_values for y in (x,) * points_per_arrow])
blocks.extend(arrows)

# Note:
# each arrow has (4 * shaft_resolution) + tip_resolution + 1 points


geometry = blocks.extract_geometry()
info = {
"mesh": geometry,
"opacity": layer_state.alpha
}
if layer_state.color_mode == "Fixed":
meshes.append(info)
if fixed_color:
info["color"] = layer_color(layer_state)
else:
sphere_points = 2 + (phi_resolution - 2) * theta_resolution # The number of points on each sphere
cmap_values = layer_state.layer[layer_state.cmap_attribute][mask]
point_cmap_values = [y for x in cmap_values for y in (x,) * sphere_points]
if layer_state.vector_visible:
arrow_points = (4 * shaft_resolution) + tip_resolution + 1
point_cmap_values.extend([y for x in cmap_values for y in (x,) * arrow_points])
geometry.point_data["colors"] = point_cmap_values
cmap = layer_state.cmap.name # This assumes that we're using a matplotlib colormap
clim = [layer_state.cmap_vmin, layer_state.cmap_vmax]
Expand All @@ -141,4 +178,32 @@ def scatter_layer_as_multiblock(viewer_state, layer_state,
info["clim"] = clim
info["scalars"] = "colors"

return info
# Add error bars
# We make these their own mesh because (for some reason) they disrupt the coloring of the
# points and arrows if they're together in one MultiBlock
# TODO: Why is this?
if any((layer_state.xerr_visible, layer_state.yerr_visible, layer_state.zerr_visible)):
bars = pv.MultiBlock()
bars_info = {}
bars_cmap_values = []
for axis in ['x', 'y', 'z']:
if getattr(layer_state, f"{axis}err_visible"):
axis_bars = meshes_for_error_bars(viewer_state, layer_state,
axis, data, bounds, mask=mask)
bars.extend(axis_bars)
if not fixed_color:
bars_cmap_values.extend([y for x in cmap_values for y in (x,) * 2]) # Each line has just two points

bars_geometry = bars.extract_geometry()
bars_info["mesh"] = bars_geometry
if fixed_color:
bars_info["color"] = layer_color(layer_state)
else:
bars_geometry.point_data["colors"] = bars_cmap_values
bars_info["cmap"] = cmap
bars_info["clim"] = clim
bars_info["scalars"] = "colors"

meshes.append(bars_info)

return meshes
9 changes: 5 additions & 4 deletions glue_ar/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ def activate(self):
layer_states = [layer.state for layer in self.viewer.layers if layer.enabled and layer.state.visible]
for layer_state in layer_states:
layer_info = dialog.state_dictionary[layer_state.layer.label].as_dict()
mesh_info = scatter_layer_as_multiblock(self.viewer.state, layer_state,
meshes = scatter_layer_as_multiblock(self.viewer.state, layer_state,
scaled=True,
clip_to_bounds=self.viewer.state.clip_data,
**layer_info)
mesh = mesh_info.pop("mesh")
if len(mesh.points) > 0:
plotter.add_mesh(mesh, **mesh_info)
for mesh_info in meshes:
mesh = mesh_info.pop("mesh")
if len(mesh.points) > 0:
plotter.add_mesh(mesh, **mesh_info)

dir, base = split(export_path)
name, ext = splitext(base)
Expand Down
1 change: 0 additions & 1 deletion glue_ar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def bounds_3d_from_layers(viewer_state, layer_states):
data = state.layer.layer
mins = [min(min(data[att]), m) for m, att in zip(mins, atts)]
maxes = [max(max(data[att]), m) for m, att in zip(maxes, atts)]
print(mins)
return [(l, u) for l, u in zip(mins, maxes)]


Expand Down

0 comments on commit f99999d

Please sign in to comment.