Skip to content

Commit

Permalink
Updating bokeh code to support 3.4.0+ (#826)
Browse files Browse the repository at this point in the history
* Updating bokeh code to support 3.4.0+

* Adjusting bokeh requirements to support Python 3.8 (bokeh >=3.0.0)

* Moving more Bokeh deprecated methods to new format.

* Changing panel requirement to support Python 3.8

* Adding myp suppressions for some Bokeh types in entity_graph_tools.py and network_plot.py

* Changing version check to Florian's more elegant version
  • Loading branch information
ianhelle authored Feb 11, 2025
1 parent 93411e7 commit a4b0b72
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 35 deletions.
4 changes: 2 additions & 2 deletions conda/conda-reqs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ azure-mgmt-resource>=16.1.0
azure-storage-blob>=12.5.0
azure-mgmt-subscription
beautifulsoup4>=4.0.0
bokeh>=1.4.0, <3.4.0
bokeh>=3.0.0
cryptography>=3.1
deprecated>=1.2.4
dnspython>=2.0.0, <3.0.0
Expand All @@ -33,7 +33,7 @@ msrestazure>=0.6.0
networkx>=2.2
numpy>=1.15.4
pandas>=1.4.0, <3.0.0
panel>=0.14.4
panel>=1.2.1
pydantic>=1.8.0, <3.0.0
pygments>=2.0.0
pyjwt>=2.3.0
Expand Down
24 changes: 20 additions & 4 deletions msticpy/vis/entity_graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# license information.
# --------------------------------------------------------------------------
"""Creates an entity graph for a Microsoft Sentinel Incident."""

from datetime import datetime, timezone
from importlib.metadata import version
from typing import List, Optional, Union

import networkx as nx
Expand All @@ -15,6 +17,7 @@
from bokeh.models import Circle, HoverTool, Label, LayoutDOM # type: ignore
from bokeh.plotting import figure, from_networkx
from dateutil import parser
from packaging.version import Version, parse

from .._version import VERSION
from ..common.exceptions import MsticpyUserError
Expand All @@ -33,6 +36,8 @@
req_alert_cols = ["DisplayName", "Severity", "AlertType"]
req_inc_cols = ["id", "name", "properties.severity"]

_BOKEH_VERSION: Version = parse(version("bokeh"))

# wrap figure function to handle v2/v3 parameter renaming
figure = bokeh_figure(figure) # type: ignore[assignment, misc]

Expand Down Expand Up @@ -508,15 +513,26 @@ def plot_entitygraph( # pylint: disable=too-many-locals
graph_renderer = from_networkx(
entity_graph_for_plotting, nx.spring_layout, scale=scale, center=(0, 0)
)
if _BOKEH_VERSION > Version("3.2.0"):
circle_parms = {
"radius": node_size // 2,
"fill_color": "node_color",
"fill_alpha": 0.5,
}
else:
circle_parms = {
"size": node_size,
"fill_color": "node_color",
"fill_alpha": 0.5,
}
graph_renderer.node_renderer.glyph = Circle(**circle_parms) # type: ignore[attr-defined]

graph_renderer.node_renderer.glyph = Circle(
size=node_size, fill_color="node_color", fill_alpha=0.5
)
# pylint: disable=no-member
plot.renderers.append(graph_renderer) # type: ignore[attr-defined]

# Create labels
for index, pos in graph_renderer.layout_provider.graph_layout.items():
label_layout = graph_renderer.layout_provider.graph_layout # type: ignore[attr-defined]
for index, pos in label_layout.items():
label = Label(
x=pos[0],
y=pos[1],
Expand Down
5 changes: 3 additions & 2 deletions msticpy/vis/matrix_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def plot_matrix(data: pd.DataFrame, **kwargs) -> LayoutDOM:
plot.add_tools(HoverTool(tooltips=tool_tips))

if param.intersect:
plot.circle_cross(
plot.scatter(
marker="circle_cross",
x=param.x_column,
y=param.y_column,
source=source,
Expand All @@ -196,7 +197,7 @@ def plot_matrix(data: pd.DataFrame, **kwargs) -> LayoutDOM:
size=5,
)
else:
plot.circle(
plot.scatter(
x=param.x_column,
y=param.y_column,
source=source,
Expand Down
35 changes: 26 additions & 9 deletions msticpy/vis/network_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# license information.
# --------------------------------------------------------------------------
"""Module for common display functions."""

from importlib.metadata import version
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import networkx as nx
Expand All @@ -22,6 +24,7 @@
)
from bokeh.palettes import Spectral4
from bokeh.plotting import figure, from_networkx, show
from packaging.version import Version, parse
from typing_extensions import Literal

from .._version import VERSION
Expand All @@ -30,6 +33,8 @@
__version__ = VERSION
__author__ = "Ian Hellen"

_BOKEH_VERSION: Version = parse(version("bokeh"))

# wrap figure function to handle v2/v3 parameter renaming
figure = bokeh_figure(figure) # type: ignore[assignment, misc]

Expand Down Expand Up @@ -173,8 +178,8 @@ def plot_nx_graph(
_create_edge_renderer(graph_renderer, edge_color=edge_color)
_create_node_renderer(graph_renderer, node_size, "node_color")

graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = EdgesAndLinkedNodes()
graph_renderer.selection_policy = NodesAndLinkedEdges() # type: ignore[assignment]
graph_renderer.inspection_policy = EdgesAndLinkedNodes() # type: ignore[assignment]
# pylint: disable=no-member
plot.renderers.append(graph_renderer) # type: ignore[attr-defined]

Expand All @@ -189,7 +194,8 @@ def plot_nx_graph(

# Create labels
# pylint: disable=no-member
for index, pos in graph_renderer.layout_provider.graph_layout.items():
label_layout = graph_renderer.layout_provider.graph_layout # type: ignore[attr-defined]
for index, pos in label_layout.items():
label = Label(
x=pos[0],
y=pos[1],
Expand Down Expand Up @@ -242,12 +248,18 @@ def _create_edge_hover(

def _create_node_renderer(graph_renderer: Renderer, node_size: int, fill_color: str):
"""Create graph render for nodes."""
graph_renderer.node_renderer.glyph = Circle(size=node_size, fill_color=fill_color)
if _BOKEH_VERSION > Version("3.2.0"):
circle_size_param = {"radius": node_size // 2}
else:
circle_size_param = {"size": node_size // 2}
graph_renderer.node_renderer.glyph = Circle(
**circle_size_param, fill_color=fill_color
)
graph_renderer.node_renderer.hover_glyph = Circle(
size=node_size, fill_color=Spectral4[1]
**circle_size_param, fill_color=Spectral4[1]
)
graph_renderer.node_renderer.selection_glyph = Circle(
size=node_size, fill_color=Spectral4[2]
**circle_size_param, fill_color=Spectral4[2]
)


Expand Down Expand Up @@ -331,14 +343,19 @@ def plot_entity_graph(
graph_renderer = from_networkx(
entity_graph, nx.spring_layout, scale=scale, center=(0, 0)
)
graph_renderer.node_renderer.glyph = Circle(
size=node_size, fill_color="node_color", fill_alpha=0.5
if _BOKEH_VERSION > Version("3.2.0"):
circle_size_param = {"radius": node_size // 2}
else:
circle_size_param = {"size": node_size // 2}
graph_renderer.node_renderer.glyph = Circle( # type: ignore[attr-defined]
**circle_size_param, fill_color="node_color", fill_alpha=0.5
)
# pylint: disable=no-member
plot.renderers.append(graph_renderer) # type: ignore[attr-defined]

# Create labels
for name, pos in graph_renderer.layout_provider.graph_layout.items():
label_layout = graph_renderer.layout_provider.graph_layout # type: ignore[attr-defined]
for name, pos in label_layout.items():
label = Label(
x=pos[0],
y=pos[1],
Expand Down
6 changes: 4 additions & 2 deletions msticpy/vis/timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ def _plot_series(data, plot, legend_pos):
if "time_column" in series_def:
time_col = series_def["time_column"]
if legend_pos == "inline":
p_series = plot.diamond(
p_series = plot.scatter(
marker="diamond",
x=time_col,
y="y_index",
color=series_def["color"],
Expand All @@ -361,7 +362,8 @@ def _plot_series(data, plot, legend_pos):
legend_label=str(ser_name),
)
else:
p_series = plot.diamond(
p_series = plot.scatter(
marker="diamond",
x=time_col,
y="y_index",
color=series_def["color"],
Expand Down
4 changes: 2 additions & 2 deletions msticpy/vis/timeline_duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

# pylint: enable=unused-import


__version__ = VERSION
__author__ = "Ian Hellen"

Expand Down Expand Up @@ -211,7 +210,8 @@ def display_timeline_duration(
plot.rect(x="Center", y=dodge("Row", 0.5), width="Width", **rect_plot_params)

# Plot the individual events as diamonds
plot.diamond(
plot.scatter(
marker="diamond",
x=time_column,
y=dodge("Row", 0.5),
color=param.color,
Expand Down
4 changes: 2 additions & 2 deletions msticpy/vis/timeline_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def display_timeline_values( # noqa: C901, MC0001
if "vbar" in plot_kinds:
plot.vbar(top=value_col, width=4, **plot_args)
if "circle" in plot_kinds:
plot.circle(y=value_col, size=4, **plot_args)
plot.scatter(y=value_col, size=2, **plot_args)
if "line" in plot_kinds:
plot.line(y=value_col, line_width=2, **plot_args)

Expand Down Expand Up @@ -341,7 +341,7 @@ def _plot_param_group(
)
if "circle" in plot_kinds:
p_series.append(
plot.circle(y=value_col, size=4, color="color", **plot_args)
plot.circle(y=value_col, radius=2, color="color", **plot_args)
)
if "line" in plot_kinds:
p_series.append(
Expand Down
12 changes: 6 additions & 6 deletions msticpy/vis/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def display_timeseries_anomalies(
plot.xaxis[0].formatter = get_tick_formatter()
plot.yaxis.formatter = NumeralTickFormatter(format="00")

plot.circle(
plot.scatter(
time_column,
value_column,
line_color=color[0],
size=4,
size=2,
source=source,
legend_label="observed",
)
Expand All @@ -231,13 +231,13 @@ def display_timeseries_anomalies(

# setting the visualization types for anomalies based on user input to kind
if kind == "cross":
plot.cross(**arg_dict)
plot.scatter(marker="cross", **arg_dict)
elif kind == "diamond":
plot.diamond(**arg_dict)
plot.scatter(marker="diamond", **arg_dict)
elif kind == "diamond_cross":
plot.diamond_cross(**arg_dict)
plot.scatter(marker="diamond_cross", **arg_dict)
else:
plot.circle_x(**arg_dict)
plot.scatter(marker="circle_x", **arg_dict)

# interactive legend to hide single/multiple plots if selected
plot.legend.location = legend_pos
Expand Down
4 changes: 2 additions & 2 deletions requirements-all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ azure-mgmt-subscription>=3.0.0
azure-monitor-query>=1.0.0, <=2.0.0
azure-storage-blob>=12.5.0
beautifulsoup4>=4.0.0
bokeh>=1.4.0, <3.4.0
bokeh>=3.0.0
cryptography>=3.1
deprecated>=1.2.4
dnspython>=2.0.0, <3.0.0
Expand All @@ -42,7 +42,7 @@ numpy>=1.15.4 # pandas
openpyxl>=3.0
packaging>=24.0
pandas>=1.4.0, <3.0.0
panel>=0.14.4
panel>=1.2.1
passivetotal>=2.5.3
autogen-agentchat[retrievechat]~=0.2.0
pydantic>=1.8.0, <3.0.0
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ azure-mgmt-keyvault>=2.0.0
azure-mgmt-subscription>=3.0.0
azure-monitor-query>=1.0.0, <=2.0.0
beautifulsoup4>=4.0.0
bokeh>=1.4.0, <3.4.0
bokeh>=3.0.0
cryptography>=3.1
deprecated>=1.2.4
dnspython>=2.0.0, <3.0.0
Expand All @@ -31,6 +31,7 @@ networkx>=2.2
numpy>=1.15.4 # pandas
packaging>=24.0
pandas>=1.4.0, <3.0.0
panel>=1.2.1
pydantic>=1.8.0, <3.0.0
pygments>=2.0.0
pyjwt>=2.3.0
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def _combine_extras(extras: list) -> list:
"azure-storage-blob>=12.5.0",
"azure-mgmt-resourcegraph>=8.0.0",
],
"azure_query": [],
"keyvault": [],
"azure_query": [], # now in core install
"keyvault": [], # now in core install
"ml": [
"scikit-learn>=1.0.0",
"scipy>=1.1.0",
Expand All @@ -52,7 +52,7 @@ def _combine_extras(extras: list) -> list:
],
"sql2kql": ["mo-sql-parsing>=8, <9.0.0"],
"riskiq": ["passivetotal>=2.5.3", "requests>=2.31.0"],
"panel": ["panel>=0.14.4"],
"panel": [], # now in core install
"aiagents": ["autogen-agentchat[retrievechat]~=0.2.0"],
}
extras_all = [
Expand Down
5 changes: 5 additions & 0 deletions tools/create_reqs_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# --------------------------------------------------------------------------
"""Requirements file writer from setup.py extras."""

import argparse
import difflib
import sys
Expand Down Expand Up @@ -206,6 +207,10 @@ def _get_extras_from_setup(
if args.diff:
# If we just wanted to check for a diff, finish here
if diff_reqs:
print(
"Differences found for setup.py + requirements.txt",
"vs. requirements-all.txt",
)
print("\n".join(diff.strip() for diff in diff_reqs))
sys.exit(1)
print("No differences for requirements-all.txt")
Expand Down

0 comments on commit a4b0b72

Please sign in to comment.