Skip to content

Commit

Permalink
refactor(analysis): add experiment 3d visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
yzx9 committed Apr 5, 2024
1 parent 90030e3 commit 3a92dcc
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
85 changes: 85 additions & 0 deletions swcgeom/analysis/visualization3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Painter utils.
Notes
-----
This is a experimental function, it may be changed in the future.
"""

from typing import Dict, Optional

import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D

from swcgeom.analysis.visualization import (
_set_ax_memo,
get_ax_color,
get_ax_swc,
set_ax_legend,
)
from swcgeom.core import SWCLike, Tree
from swcgeom.utils.plotter_3d import draw_lines_3d

__all__ = ["draw3d"]


# TODO: support Camera
def draw3d(
swc: SWCLike | str,
*,
ax: Axes,
show: bool | None = None,
color: Optional[Dict[int, str] | str] = None,
label: str | bool = True,
**kwargs,
) -> tuple[Figure, Axes]:
r"""Draw neuron tree.
Parameters
----------
swc : SWCLike | str
If it is str, then it is treated as the path of swc file.
fig : ~matplotlib.axes.Figure, optional
ax : ~matplotlib.axes.Axes, optional
show : bool | None, default `None`
Wheather to call `plt.show()`. If not specified, it will depend
on if ax is passed in, it will not be called, otherwise it will
be called by default.
color : Dict[int, str] | "vaa3d" | str, optional
Color map. If is dict, segments will be colored by the type of
parent node.If is string, the value will be use for any type.
label : str | bool, default True
Label of legend, disable if False.
**kwargs : dict[str, Unknown]
Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
"""

assert isinstance(ax, Axes3D), "only support 3D axes."

swc = Tree.from_swc(swc) if isinstance(swc, str) else swc

show = (show is True) or (show is None and ax is None)
my_color = get_ax_color(ax, swc, color) # type: ignore

xyz = swc.xyz()
starts, ends = swc.id()[1:], swc.pid()[1:]
lines = np.stack([xyz[starts], xyz[ends]], axis=1)
collection = draw_lines_3d(ax, lines, color=my_color, **kwargs)

min_vals = lines.reshape(-1, 3).min(axis=0)
max_vals = lines.reshape(-1, 3).max(axis=0)
ax.set_xlim(min_vals[0], max_vals[0])
ax.set_ylim(min_vals[1], max_vals[1])
ax.set_zlim(min_vals[2], max_vals[2])

_set_ax_memo(ax, swc, label=label, handle=collection)

if len(get_ax_swc(ax)) == 1:
# ax.set_aspect(1)
ax.spines[["top", "right"]].set_visible(False)
else:
set_ax_legend(ax, loc="upper right") # enable legend

fig = ax.figure
return fig, ax # type: ignore
31 changes: 31 additions & 0 deletions swcgeom/utils/plotter_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""3D Plotting utils."""

import numpy as np
import numpy.typing as npt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection

__all__ = ["draw_lines_3d"]


def draw_lines_3d(
ax: Axes3D,
lines: npt.NDArray[np.floating],
joinstyle="round",
capstyle="round",
**kwargs,
):
"""Draw lines.
Parameters
----------
ax : ~matplotlib.axes.Axes
lines : A collection of coords of lines
Excepting a ndarray of shape (N, 2, 3), the axis-2 holds two points,
and the axis-3 holds the coordinates (x, y, z).
**kwargs : dict[str, Unknown]
Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
"""

line_collection = Line3DCollection(lines, joinstyle=joinstyle, capstyle=capstyle, **kwargs) # type: ignore
return ax.add_collection3d(line_collection)

0 comments on commit 3a92dcc

Please sign in to comment.