Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from cubed.backend_array_api import namespace as nxp
from cubed.core.array import CoreArray
from cubed.core.ops import elemwise
from cubed.diagnostics.widgets import get_template
from cubed.utils import itemsize, memory_repr
from cubed.vendor.dask.widgets import get_template

ARRAY_SVG_SIZE = (
120 # cubed doesn't have a config module like dask does so hard-code this for now
Expand Down
6 changes: 5 additions & 1 deletion cubed/core/array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from operator import mul
from typing import Optional, TypeVar
from typing import Literal, Optional, TypeVar

from toolz import map, reduce

Expand Down Expand Up @@ -205,6 +205,7 @@ def visualize(
optimize_graph=True,
optimize_function=None,
show_hidden=False,
engine: Literal["cytoscape", "graphviz"] | None = None,
):
"""Produce a visualization of the computation graph for this array.

Expand Down Expand Up @@ -238,6 +239,7 @@ def visualize(
optimize_graph=optimize_graph,
optimize_function=optimize_function,
show_hidden=show_hidden,
engine=engine,
)

def __getitem__(self: T_ChunkedArray, key, /) -> T_ChunkedArray:
Expand Down Expand Up @@ -329,6 +331,7 @@ def visualize(
optimize_graph=True,
optimize_function=None,
show_hidden=False,
engine: Literal["cytoscape", "graphviz"] | None = None,
):
"""Produce a visualization of the computation graph for multiple arrays.

Expand Down Expand Up @@ -366,6 +369,7 @@ def visualize(
filename=filename,
format=format,
show_hidden=show_hidden,
engine=engine,
)


Expand Down
36 changes: 34 additions & 2 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import networkx as nx

Expand Down Expand Up @@ -609,7 +609,16 @@ def visualize(
format=None,
rankdir="TB",
show_hidden=False,
engine: Literal["cytoscape", "graphviz"] | None = None,
):
if engine == "cytoscape":
return self.visualize_cytoscape(
filename,
format=format,
rankdir=rankdir,
show_hidden=show_hidden,
)

if self._ops_exceeding_memory:
op_names = [name for name, _ in self._ops_exceeding_memory]
warnings.warn(
Expand Down Expand Up @@ -672,7 +681,8 @@ def visualize(
stacks.append(stack_summaries)
# add current stack info
# go back one in the stack to the caller of 'visualize'
frame = inspect.currentframe().f_back
frame = inspect.currentframe()
frame = frame.f_back if frame is not None else frame
stack_summaries = extract_stack_summaries(frame, limit=10)
stacks.append(stack_summaries)
array_display_names = extract_array_names_from_stack_summaries(stacks)
Expand Down Expand Up @@ -800,6 +810,28 @@ def visualize(
pass
return None

def visualize_cytoscape(
self,
filename="cubed",
format=None,
rankdir="TB",
show_hidden=False,
):
from cubed.diagnostics.widgets.plan import create_or_update_plan_widget

widget = create_or_update_plan_widget(self, rankdir=rankdir)

if filename is not None:
from ipywidgets.embed import embed_minimal_html

if format is None:
format = "html"
full_filename = f"{filename}.{format}"
embed_minimal_html(
full_filename, views=[widget], title="Cubed plan", drop_defaults=False
)
return widget


@dataclasses.dataclass
class ComputeStartEventWithPlan(ComputeStartEvent):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
try:
from cubed.vendor.dask.widgets.widgets import (
from cubed.diagnostics.widgets.core import (
FILTERS,
TEMPLATE_PATHS,
get_environment,
get_template,
)

from .plan import LivePlanViewer, PlanWidget

__all__ = ["LivePlanViewer", "PlanWidget"]

except ImportError as e:
msg = (
"Cubed diagnostics requirements are not installed.\n\n"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,15 @@
import datetime
import html
import os.path

from jinja2 import Environment, FileSystemLoader, Template
from jinja2.exceptions import TemplateNotFound

from cubed.vendor.dask.utils import (
format_bytes,
format_time,
format_time_ago,
key_split,
typename,
)
from cubed.utils import format_int, memory_repr

FILTERS = {
"datetime_from_timestamp": datetime.datetime.fromtimestamp,
"format_bytes": format_bytes,
"format_time": format_time,
"format_time_ago": format_time_ago,
"format_int": format_int,
"html_escape": html.escape,
"key_split": key_split,
"type": type,
"typename": typename,
"memory_repr": memory_repr,
}

TEMPLATE_PATHS = [os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")]
Expand All @@ -31,14 +19,11 @@ def get_environment() -> Environment:
loader = FileSystemLoader(TEMPLATE_PATHS)
environment = Environment(loader=loader)
environment.filters.update(FILTERS)

return environment


def get_template(name: str) -> Template:
try:
return get_environment().get_template(name)
except TemplateNotFound as e:
raise TemplateNotFound(
f"Unable to find {name} in dask.widgets.TEMPLATE_PATHS {TEMPLATE_PATHS}"
) from e
raise TemplateNotFound(f"Unable to find {name} in {TEMPLATE_PATHS}") from e
Loading
Loading