Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better graph colouring #23

Merged
merged 1 commit into from
Jan 1, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 81 additions & 17 deletions osaca/semantics/kernel_dg.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,10 @@ def export_graph(self, filepath=None):
for dep in lcd:
lcd_line_numbers[dep] = [x.line_number for x, lat in lcd[dep]["dependencies"]]
# add color scheme
graph.graph["node"] = {"colorscheme": "set312"}
graph.graph["edge"] = {"colorscheme": "set312"}
graph.graph["node"] = {"colorscheme": "spectral9"}
graph.graph["edge"] = {"colorscheme": "spectral9"}
min_color = 2
available_colors = 8

# create LCD edges
for dep in lcd_line_numbers:
Expand All @@ -543,21 +545,14 @@ def export_graph(self, filepath=None):
for n in cp:
graph.nodes[n.line_number]["instruction_form"].latency_cp = n.latency_cp

# color CP and LCD
# Make the critical path bold.
for n in graph.nodes:
if n in cp_line_numbers:
# graph.nodes[n]['color'] = 1
graph.nodes[n]["style"] = "bold"
graph.nodes[n]["penwidth"] = 4
for col, dep in enumerate(lcd):
if n in lcd_line_numbers[dep]:
if "style" not in graph.nodes[n]:
graph.nodes[n]["style"] = "filled"
elif ",filled" not in graph.nodes[n]["style"]:
graph.nodes[n]["style"] += ",filled"
graph.nodes[n]["fillcolor"] = 2 + col % 11

# color edges
# Make critical path edges bold.
for e in graph.edges:
if (
graph.nodes[e[0]]["instruction_form"].line_number in cp_line_numbers
Expand All @@ -571,12 +566,81 @@ def export_graph(self, filepath=None):
if bold_edge:
graph.edges[e]["style"] = "bold"
graph.edges[e]["penwidth"] = 3
for dep in lcd_line_numbers:
if (
graph.nodes[e[0]]["instruction_form"].line_number in lcd_line_numbers[dep]
and graph.nodes[e[1]]["instruction_form"].line_number in lcd_line_numbers[dep]
):
graph.edges[e]["color"] = graph.nodes[e[1]]["fillcolor"]

# Color the cycles created by loop-carried dependencies, longest first, never recoloring
# any node, so that the longest LCD and most long chains that are involved in the loop are
# legible.
for i, dep in enumerate(sorted(lcd, key=lambda dep: -lcd[dep]["latency"])):
# For cycles that are broken by already-colored (longer) cycles, the color need not be
# the same for each yet-uncolored arc.
# Do not use the same color for such an arc as for the cycles that delimit it. This is
# always possible with 3 colors, as each arc is only adjacent to the preceding and
# following interrupting cycles.
# Since we color edges as well as nodes, there would be room for a more interesting
# graph coloring problem: we could avoid having unrelated arcs with the same color
# meeting at the same vertex, and retain the same color between arcs of the same cycle
# that are interrupted by a single vertex. We mostly ignore this problem.

# The longest cycle will always have color 1, the second longest cycle will always have
# color 2 except where it overlaps with with the longest cycle, etc.; for arcs that are
# part of short cycles, the colors will be less predictable.
default_color = min_color + i % available_colors
arc = []
arc_source = lcd_line_numbers[dep][-1]
arcs = []
for n in lcd_line_numbers[dep]:
if "fillcolor" in graph.nodes[n]:
arcs.append((arc, (arc_source, n)))
arc = []
arc_source = n
else:
arc.append(n)
if not arcs: # Unconstrained cycle.
arcs.append((arc, tuple()))
else:
arcs.append((arc, (arc_source, lcd_line_numbers[dep][0])))
# Try to color the whole cycle with its default color, then with a single color, then
# with different colors by arc, preferring the default.
forbidden_colors = set(
graph.nodes[n]["fillcolor"] for arc, extremities in arcs for n in extremities
if "fillcolor" in graph.nodes[n]
)
global_color = None
if default_color not in forbidden_colors:
global_color = default_color
elif len(forbidden_colors) < available_colors:
global_color = next(
c for c in range(min_color, min_color + available_colors + 1)
if c not in forbidden_colors
)
for arc, extremities in arcs:
if global_color:
color = global_color
else:
color = default_color
while color in (graph.nodes[n].get("fillcolor") for n in extremities):
color = min_color + (color + 1) % available_colors
for n in arc:
if "style" not in graph.nodes[n]:
graph.nodes[n]["style"] = "filled"
else:
graph.nodes[n]["style"] += ",filled"
graph.nodes[n]["fillcolor"] = color
if extremities:
(source, sink) = extremities
else:
source = sink = arc[0]
arc = arc[1:]
for u, v in zip([source] + arc, arc + [sink]):
# The backward edge of the cycle is represented as the corresponding forward
# edge with the attribute dir=back.
edge = graph.edges[v, u] if (v, u) in graph.edges else graph.edges[u, v]
if arc:
if "color" in edge:
raise AssertionError(
f"Recoloring {u}->{v} in arc ({source}) {arc} ({sink}) of {dep}"
)
edge["color"] = color

# rename node from [idx] to [idx mnemonic] and add shape
mapping = {}
Expand Down