|
| 1 | +using Symbolics, SymbolicUtils |
| 2 | +using Graphs, GraphMakie, CairoMakie, LayeredLayouts |
| 3 | +using Printf |
| 4 | + |
| 5 | +function generate_dag(rexs...) |
| 6 | + exs = map(Symbolics.unwrap, rexs) |
| 7 | + dag = SimpleDiGraph() |
| 8 | + vertex_map = IdDict{Any, Int}() |
| 9 | + label_map = Dict{Int, String}() |
| 10 | + unicode_replace = Dict( |
| 11 | + "-" => "−", |
| 12 | + "*" => "×", |
| 13 | + "/" => "÷" |
| 14 | + ) |
| 15 | + function dfs(node) |
| 16 | + if haskey(vertex_map, node) |
| 17 | + return vertex_map[node] |
| 18 | + end |
| 19 | + add_vertex!(dag) |
| 20 | + v = nv(dag) |
| 21 | + vertex_map[node] = v |
| 22 | + if iscall(node) |
| 23 | + op = operation(node) |
| 24 | + args = map(dfs, arguments(node)) |
| 25 | + for arg in args |
| 26 | + add_edge!(dag, v, arg) |
| 27 | + end |
| 28 | + rop = repr(op) |
| 29 | + label_map[v] = get(unicode_replace, rop, rop) |
| 30 | + elseif node isa Number |
| 31 | + label_map[v] = @sprintf "%.1g" node |
| 32 | + else |
| 33 | + label_map[v] = repr(node) |
| 34 | + end |
| 35 | + return v |
| 36 | + end |
| 37 | + for ex in exs |
| 38 | + dfs(ex) |
| 39 | + end |
| 40 | + return dag, label_map |
| 41 | +end |
| 42 | + |
| 43 | +function plot_dag(dag, label_map) |
| 44 | + xs, ys, paths = solve_positions(Zarate(), dag) |
| 45 | + for (key, value) in paths |
| 46 | + paths[key] = (value[2], -value[1]) |
| 47 | + end |
| 48 | + lay = Point.(zip(ys, -xs)) |
| 49 | + wp = [Point2f.(zip(paths[e]...)) for e in edges(dag)] |
| 50 | + fig, ax, p = graphplot( |
| 51 | + dag; layout = lay, ilabels = [label_map[v] for v in 1:nv(dag)], |
| 52 | + waypoints = wp, node_color = :white) |
| 53 | + hidedecorations!(ax) |
| 54 | + hidespines!(ax) |
| 55 | + fig |
| 56 | +end |
| 57 | + |
| 58 | +@variables a b c d |
| 59 | +x = (a + b) * (c + d) |
| 60 | +y = (a - b) * (c - d) |
| 61 | +z = (a + b) * (c - d) |
| 62 | +w = (a - b) * (c + d) |
| 63 | +p = x + y |
| 64 | +q = z - w |
| 65 | +rex = (x + y) / (z - w) |
| 66 | + |
| 67 | +dag, label_map = generate_dag(p, q) |
| 68 | +fig = plot_dag(dag, label_map) |
| 69 | +save("expression_dag.png", fig) |
0 commit comments