diff --git a/Project.toml b/Project.toml index 8a0c133..5952890 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,7 @@ LuxorTensorPlot = ["LuxorGraphPlot"] [compat] AbstractTrees = "0.3, 0.4" Aqua = "0.8" -CliqueTrees = "1.5.0" +CliqueTrees = "1.12.1" DataStructures = "0.18" Documenter = "1.10.1" Graphs = "1" diff --git a/src/OMEinsumContractionOrders.jl b/src/OMEinsumContractionOrders.jl index a00387e..e44c979 100644 --- a/src/OMEinsumContractionOrders.jl +++ b/src/OMEinsumContractionOrders.jl @@ -10,7 +10,7 @@ using TreeWidthSolver using TreeWidthSolver.Graphs using DataStructures: PriorityQueue, enqueue!, dequeue!, peek, dequeue_pair! import CliqueTrees -using CliqueTrees: cliquetree, residual, EliminationAlgorithm, MMW, BFS, MCS, LexBFS, RCMMD, RCMGL, MCSM, LexM, AMF, MF, MMD, MF, BT, SafeRules, KaHyParND, METISND, ND, BestWidth +using CliqueTrees: cliquetree, cliquetree!, separator, residual, CliqueTree, EliminationAlgorithm, MMW, BFS, MCS, LexBFS, RCMMD, RCMGL, MCSM, LexM, AMF, MF, MMD, MF, BT, SafeRules, KaHyParND, METISND, ND, BestWidth, ConnectedComponents # interfaces export simplify_code, optimize_code, slice_code, optimize_permute, label_elimination_order, uniformsize, ScoreFunction diff --git a/src/hypernd.jl b/src/hypernd.jl index 6fbe07a..2234ad2 100644 --- a/src/hypernd.jl +++ b/src/hypernd.jl @@ -4,6 +4,7 @@ algs = (MF(), AMF(), MMD()), level = 6, width = 120, + scale = 100, imbalances = 130:130, score = ScoreFunction(), ) @@ -11,7 +12,11 @@ Nested-dissection based optimizer. Recursively partitions a tensor network, then calls a greedy algorithm on the leaves. The optimizer is run a number of times: once for each greedy algorithm in `algs` and each imbalance value in `imbalances`. The recursion depth is controlled by -the parameters `level` and `width`. +the parameters `level` and `width`. 
The parameter `scale` controls discretization of the index weights: + + weight(i) := scale * log2(dim(i)) + +where dim(i) is the dimension of the index i. The line graph is partitioned using the algorithm `dis`. OMEinsumContractionOrders currently supports two partitioning algorithms, both of which require importing an external library. @@ -37,16 +42,18 @@ The optimizer is implemented using the tree decomposition library dis::D = KaHyParND() algs::A = (MF(), AMF(), MMD()) level::Int = 6 - width::Int = 120 - imbalances::StepRange{Int, Int} = 130:1:130 + width::Int = 50 + scale::Int = 100 + imbalances::StepRange{Int, Int} = 100:10:800 score::ScoreFunction = ScoreFunction() end -function optimize_hyper_nd(optimizer::HyperND, code, size_dict) +function optimize_hyper_nd(optimizer::HyperND, code::AbstractEinsum, size_dict::AbstractDict; binary::Bool=true) dis = optimizer.dis algs = optimizer.algs level = optimizer.level width = optimizer.width + scale = optimizer.scale imbalances = optimizer.imbalances score = optimizer.score @@ -54,9 +61,9 @@ function optimize_hyper_nd(optimizer::HyperND, code, size_dict) local mincode for imbalance in imbalances - curalg = SafeRules(ND(BestWidth(algs), dis; level, width, imbalance)) - curoptimizer = Treewidth(; alg=curalg) - curcode = _optimize_code(code, size_dict, curoptimizer) + curalg = SafeRules(ND(BestWidth(algs), dis; level, width, scale, imbalance)) + curopt = Treewidth(; alg=curalg) + curcode = optimize_treewidth(curopt, code, size_dict; binary=false) curtc, cursc, currw = __timespacereadwrite_complexity(curcode, size_dict) if score(curtc, cursc, currw) < minscore @@ -64,6 +71,10 @@ function optimize_hyper_nd(optimizer::HyperND, code, size_dict) end end + if binary + mincode = _optimize_code(mincode, size_dict, GreedyMethod()) + end + return mincode end @@ -77,7 +88,8 @@ function Base.show(io::IO, ::MIME"text/plain", optimizer::HyperND{D, A}) where { println(io, " level: $(optimizer.level)") println(io, " width: 
$(optimizer.width)") + println(io, " scale: $(optimizer.scale)") println(io, " imbalances: $(optimizer.imbalances)") - println(io, " target: $(optimizer.target)") + println(io, " score: $(optimizer.score)") return end diff --git a/src/treewidth.jl b/src/treewidth.jl index e95bb51..d72b2a9 100644 --- a/src/treewidth.jl +++ b/src/treewidth.jl @@ -6,13 +6,6 @@ Tree width based solver. The solvers are implemented in [CliqueTrees.jl](https:/ | Algorithm | Description | Time Complexity | Space Complexity | |:-----------|:-------------|:----------------|:-----------------| -| `BFS` | breadth-first search | O(m + n) | O(n) | -| `MCS` | maximum cardinality search | O(m + n) | O(n) | -| `LexBFS` | lexicographic breadth-first search | O(m + n) | O(m + n) | -| `RCMMD` | reverse Cuthill-Mckee (minimum degree) | O(m + n) | O(m + n) | -| `RCMGL` | reverse Cuthill-Mckee (George-Liu) | O(m + n) | O(m + n) | -| `MCSM` | maximum cardinality search (minimal) | O(mn) | O(n) | -| `LexM` | lexicographic breadth-first search (minimal) | O(mn) | O(n) | | `AMF` | approximate minimum fill | O(mn) | O(m + n) | | `MF` | minimum fill | O(mn²) | - | | `MMD` | multiple minimum degree | O(mn²) | O(m + n) | @@ -39,15 +32,15 @@ Dict{Char, Int64} with 6 entries: 'b' => 4 julia> optcode = optimize_code(eincode, size_dict, optimizer) -ba, ab -> a -├─ bcf, fac -> ba -│ ├─ e, bcef -> bcf -│ │ ├─ e -│ │ └─ bcef -│ └─ df, acd -> fac -│ ├─ df -│ └─ acd -└─ ab +ab, ba -> a +├─ ab +└─ bcf, acf -> ba + ├─ bcef, e -> bcf + │ ├─ bcef + │ └─ e + └─ acd, df -> acf + ├─ acd + └─ df ``` """ Base.@kwdef struct Treewidth{EL <: EliminationAlgorithm} <: CodeOptimizer @@ -64,149 +57,307 @@ The `BT` algorithm is an exact solver for the treewidth problem that implemented const ExactTreewidth = Treewidth{SafeRules{BT, MMW{3}, MF}} ExactTreewidth() = Treewidth() -# calculates the exact treewidth of a graph using TreeWidthSolver.jl. 
It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input. -# Return: a `ContractionTree` representing the contraction process. -# -# - `incidence_list`: An incidence list representation of the graph. -# - `log2_edge_sizes`: A dictionary of logarithm base 2 edge sizes. -# - `alg`: The algorithm to use for the treewidth calculation. -function treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes, alg) where {VT,ET} - indices = collect(keys(incidence_list.e2v)) - tensors = collect(keys(incidence_list.v2e)) - weights = [log2_edge_sizes[e] for e in indices] - line_graph = il2lg(incidence_list, indices) - - scalars = [i for i in tensors if isempty(incidence_list.v2e[i])] - contraction_trees = Vector{Union{ContractionTree, VT}}() - - # avoid the case that the line graph is not connected - for vertice_ids in connected_components(line_graph) - lg = induced_subgraph(line_graph, vertice_ids)[1] - lg_indices = indices[vertice_ids] - lg_weights = weights[vertice_ids] - - # construct tree decomposition - perm, tree = cliquetree(lg_weights, lg; alg) # `tree` is a vector of cliques - permute!(lg_indices, perm) # `perm` is a permutation - - # construct elimination ordering - eo = map(Base.Iterators.reverse(tree)) do clique - # the vertices in `res` can be eliminated at the same time - res = residual(clique) # `res` is a unit range - return @view lg_indices[res] - end +""" + optimize_treewidth(optimizer, eincode, size_dict; binary=true) + +Optimize the contraction order by computing a tree decomposition of the line graph corresponding to the eincode with the elimination algorithm `optimizer.alg`, and return a `NestedEinsum` object. +If `binary` is true, the resulting contraction tree is binarized with `GreedyMethod`.
+""" +function optimize_treewidth(optimizer::Treewidth, code::AbstractEinsum, size_dict::Dict; binary::Bool=true) + optimize_treewidth(optimizer, getixsv(code), getiyv(code), size_dict; binary) +end - lg_e2v = Dict{ET, Vector{VT}}() - lg_v2e = Dict{VT, Vector{ET}}() +function optimize_treewidth(optimizer::Treewidth, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L, Int}; binary::Bool=true) where {L} + marker = zeros(Int, max(length(ixs) + 1, length(size_dict))) - for es in eo, e in es - vs = lg_e2v[e] = incidence_list.e2v[e] + # construct incidence matrix `ve` + # indices + # [ ] + # tensors [ ve ] + # [ ] + # we only care about the sparsity pattern + weights, ev, ve, el = einexpr_to_matrix!(marker, ixs, iy, size_dict) - for v in vs - if !haskey(lg_v2e, v) - lg_v2e[v] = ET[] - end + # compute a tree (forest) decomposition of `ve` + tree = matrix_to_tree!(marker, weights, ev, ve, el, optimizer.alg) - push!(lg_v2e[v], e) - end - end + # transform tree decomposition into a contraction tree + code = tree_to_einexpr!(marker, tree, ve, el, ixs, iy) - lg_incidence_list = IncidenceList(lg_v2e, lg_e2v, ET[]) - contraction_tree = eo2ct(eo, lg_incidence_list, log2_edge_sizes) - push!(contraction_trees, contraction_tree) + if binary + # binarize contraction tree + code = _optimize_code(code, size_dict, GreedyMethod()) end - # add the scalars back to the contraction tree - return reduce((x,y) -> ContractionTree(x, y), contraction_trees ∪ scalars) + return code end -# transform incidence list to line graph -function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) where {VT, ET} +""" + einexpr_to_matrix!(marker, ixs, iy, size_dict) - line_graph = SimpleGraph(length(indicies)) - - for (i, e) in enumerate(indicies) - for v in incidence_list.e2v[e] - for ej in incidence_list.v2e[v] - if e != ej add_edge!(line_graph, i, findfirst(==(ej), indicies)) end +Construct the weighted incidence matrix corresponding to an Einstein summation expression.
+Returns a quadruple (weights, ev, ve, el). + +Each Einstein summation expression has a set E ⊆ L of indices, a set V := {1, …, |V|} of +(inner) tensors, and an outer tensor * := |V| + 1. Each tensor v ∈ V is incident to a sequence +ixs[v] of indices, and the outer tensor is incident to the sequence iy. Note that an index +can appear multiple times in ixs[v], e.g. + + ixs[v] = ('a', 'a', 'b'). + +Each index l ∈ E also has a positive dimension, given by size_dict[l]. + +The function `einexpr_to_matrix!` does two things. First of all, it enumerates the index set +E, mapping each index to a distinct natural number. + + el: {1, …, |E|} → E + le: E → {1, …, |E|} + +Next, it constructs a vector weights: {1, …, |E|} → [0, ∞) satisfying + + weights[e] := log2(size_dict[el[e]]), + +and a sparse matrix ve: {1, …, |V| + 1} × {1, …, |E|} → {0, 1} satisfying + + ve[v, e] := { 1 if el[e] is incident to v + { 0 otherwise + +We can think of the pair H := (weights, ve) as an edge-weighted hypergraph +with incidence matrix ve. +""" +function einexpr_to_matrix!(marker::AbstractVector{Int}, ixs::AbstractVector{<:AbstractVector{L}}, iy::AbstractVector{L}, size_dict::AbstractDict{L}) where {L} + m = length(size_dict) + n = length(ixs) + 1 + o = sum(length, ixs) + length(iy) + + # construct incidence matrix `ve` + # indices + # [ ] + # tensors [ ve ] + # [ ] + # we only care about the sparsity pattern + le = sizehint!(Dict{L, Int}(), m); el = sizehint!(L[], m) # el ∘ le = id + weights = sizehint!(Float64[], m) + colptr = sizehint!(Int[1], n + 1) + rowval = sizehint!(Int[], o) + nzval = sizehint!(Int[], o) + + # for all tensors v... + for (v, ix) in enumerate(ixs) + # for each index l incident to v... + for l in ix + # let e := le[l] + if haskey(le, l) + e = le[l] + else + push!(weights, log2(size_dict[l])) + push!(el, l) + e = le[l] = length(el) end - end - end - return line_graph -end + # if l has not been seen before in ixs[v]...
+ if marker[e] < v + # mark e as seen + marker[e] = v -# transform elimination order to contraction tree -function eo2ct(elimination_order::Vector{<:AbstractVector{TL}}, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes) where {TL, VT, ET} - eo = copy(elimination_order) - incidence_list = copy(incidence_list) - contraction_tree_nodes = Vector{Union{VT, ContractionTree}}(collect(keys(incidence_list.v2e))) - tensors_list = Dict{VT, Int}() - for (i, v) in enumerate(contraction_tree_nodes) - tensors_list[v] = i + # set ev[e, v] := 1 + push!(rowval, e) + push!(nzval, 1) + end + end + + push!(colptr, length(rowval) + 1) end - flag = contraction_tree_nodes[1] - - while !isempty(eo) - eliminated_vertices = pop!(eo) # e is a vector of vertices, which are eliminated at the same time - vs = unique!(vcat([incidence_list.e2v[ei] for ei in eliminated_vertices if haskey(incidence_list.e2v, ei)]...)) # the tensors to be contracted, since they are connected to the eliminated vertices - if length(vs) >= 2 - sub_list_indices = unique!(vcat([incidence_list.v2e[v] for v in vs]...)) # the vertices connected to the tensors to be contracted - sub_list_open_indices = setdiff(sub_list_indices, eliminated_vertices) # the vertices connected to the tensors to be contracted but not eliminated - vmap = Dict([i => incidence_list.v2e[v] for (i, v) in enumerate(vs)]) - sub_list = IncidenceList(vmap; openedges=sub_list_open_indices) # the subgraph of the contracted tensors - sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; α=0.0, temperature=0.0) # optmize the subgraph with greedy method - sub_tree = expand_indices(sub_tree, Dict([i => v for (i, v) in enumerate(vs)])) - vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs) # insert the contracted tensors back to the total graph - contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes) - flag = vi + # v is the outer tensor + v = length(colptr) + + # for each index l incident 
to v... + for l in iy + # let e := le[l] + if haskey(le, l) + e = le[l] + else + push!(weights, log2(size_dict[l])) + push!(el, l) + e = le[l] = length(el) end + + # if l has not been seen before in iy... + if marker[e] < v + # mark e as seen + marker[e] = v + + # set ev[e, v] = 1 + push!(rowval, e) + push!(nzval, 1) + end end - return contraction_tree_nodes[tensors_list[flag]] + push!(colptr, length(rowval) + 1) + + m = length(el) + n = length(colptr) - 1 + + ev = SparseMatrixCSC{Int, Int}(m, n, colptr, rowval, nzval) + ve = copy(transpose(ev)) + return weights, ev, ve, el end -function expand_indices(sub_tree::Union{ContractionTree, VT}, vmap::Dict{Int, VT}) where{VT} - if sub_tree isa ContractionTree - return ContractionTree(expand_indices(sub_tree.left, vmap), expand_indices(sub_tree.right, vmap)) - else - return vmap[sub_tree] +""" + matrix_to_tree!(marker, weights, ev, ve, el, alg) + +Construct a tree decomposition of an edge-weighted hypergraph using +the elimination algorithm `alg`. We ensure that the indices incident +to the outer tensor are contained in the root bag of the tree decomposition. 
+""" +function matrix_to_tree!(marker::AbstractVector{Int}, weights::AbstractVector{Float64}, ev::SparseMatrixCSC{Int, Int}, ve::SparseMatrixCSC{Int, Int}, el::AbstractVector{L}, alg::EliminationAlgorithm) where {L} + n, m = size(ve); tag = n + 1 + + # construct line graph `ee` + # indices + # [ ] + # indices [ ee ] + # [ ] + # we only care about the sparsity pattern + ee = ve' * ve + + # compute a tree (forest) decomposition of ee + perm, tree = cliquetree(weights, ee; alg=ConnectedComponents(alg)) + + # find the bag containing iy, call it root + root = length(tree) + + # mark the indices in iy + for e in view(rowvals(ev), nzrange(ev, n)) + marker[e] = tag end -end -function st2ct(sub_tree::Union{ContractionTree, VT}, tensors_list::Dict{VT, Int}, contraction_tree_nodes::Vector) where{VT} - if sub_tree isa ContractionTree - return ContractionTree(st2ct(sub_tree.left, tensors_list, contraction_tree_nodes), st2ct(sub_tree.right, tensors_list, contraction_tree_nodes)) - else - return contraction_tree_nodes[tensors_list[sub_tree]] + # the first bag containing an index in iy + # must contain all of iy + for (b, bag) in enumerate(tree) + root < length(tree) && break + + for e in residual(bag) + root < length(tree) && break + + if marker[perm[e]] == tag + root = b + end + end end + + # make root a root node of the tree decomposition + permute!(perm, cliquetree!(tree, root)) + + # permute incidence matrix `ve` and label vector `el`. + permute!(ve, axes(ve, 1), perm) + permute!(el, perm) + + return tree end """ - optimize_treewidth(optimizer, eincode, size_dict) + tree_to_einexpr!(marker, tree, ve, el, ixs, iy) -Optimizing the contraction order via solve the exact tree width of the line graph corresponding to the eincode and return a `NestedEinsum` object. -Check the docstring of `treewidth_method` for detailed explaination of other input arguments. +Transform a tree decomposition into a contraction tree. 
""" -function optimize_treewidth(optimizer::Treewidth{EL}, code::AbstractEinsum, size_dict::Dict) where {EL} - optimize_treewidth(optimizer, getixsv(code), getiyv(code), size_dict) -end -function optimize_treewidth(optimizer::Treewidth{EL}, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L,TI}) where {L, TI, EL} - if length(ixs) <= 2 - return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy)) +function tree_to_einexpr!(marker::AbstractVector{Int}, tree::CliqueTree{Int, Int}, ve::SparseMatrixCSC{Int, Int}, el::AbstractVector{L}, ixs::AbstractVector{<:AbstractVector{L}}, iy::AbstractVector{L}) where {L} + n, m = size(ve); tag = n + 2 + + # dynamic programming + stack = NestedEinsum{L}[] + + # for each bag b... + for (b, bag) in enumerate(tree) + # sep is the separator at b + sep = separator(bag) + + # res is the residual at b + res = residual(bag) + + # code is the Einstein summation expression at b + code = NestedEinsum(NestedEinsum{L}[], EinCode(Vector{L}[], L[])) + + for e in sep + push!(code.eins.iy, el[e]) + end + + # for each index e in the residual... + for e in res + # for each tensor v indicent to e... + for v in view(rowvals(ve), nzrange(ve, e)) + # if has not been seen before... + if marker[v] < tag + # mark v as seen + marker[v] = tag + + # if v is the outer tensor... + if v == n + # expose iy + append!(code.eins.iy, iy) + # if v is an inner tensor... + else + # make v a child of code + push!(code.args, NestedEinsum{L}(v)) + push!(code.eins.ixs, ixs[v]) + end + end + end + end + + # for each child bag of b... 
+ for _ in childindices(tree, b) + # the Einstein summation expression corresponding to + # the child bag is at the top of the stack + child = pop!(stack) + + # make this expression a child of code + push!(code.args, child) + push!(code.eins.ixs, child.eins.iy) + end + + # push code to the stack + push!(stack, code) end - log2_edge_sizes = Dict{L,Float64}() - for (k, v) in size_dict - log2_edge_sizes[k] = log2(v) + + # we now have an expression for each root of the tree decomposition. + # merge these together into a single Einstein expression code. + if isone(length(stack)) + code = only(stack) + else + code = NestedEinsum(NestedEinsum{L}[], EinCode(Vector{L}[], L[])) + append!(code.eins.iy, iy) + + while !isempty(stack) + child = pop!(stack) + push!(code.args, child) + push!(code.eins.ixs, child.eins.iy) + end end - # complete all open edges as a clique, connected with a dummy tensor - incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)] ∪ [(length(ixs) + 1 => iy)])) - tree = treewidth_method(incidence_list, log2_edge_sizes, optimizer.alg) + # append scalars to code + for (v, ix) in enumerate(ixs) + if isempty(ix) + push!(code.args, NestedEinsum{L}(v)) + push!(code.eins.ixs, ix) + end + end + + return code +end - # remove the dummy tensor added for open edges - optcode = parse_eincode!(incidence_list, tree, 1:length(ixs) + 1, size_dict)[2] +# no longer used +function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) where {VT, ET} + line_graph = SimpleGraph(length(indicies)) + + for (i, e) in enumerate(indicies) + for v in incidence_list.e2v[e] + for ej in incidence_list.v2e[v] + if e != ej add_edge!(line_graph, i, findfirst(==(ej), indicies)) end + end + end + end - return pivot_tree(optcode, length(ixs) + 1) + return line_graph end diff --git a/test/treewidth.jl b/test/treewidth.jl index bc50b66..cba53af 100644 --- a/test/treewidth.jl +++ b/test/treewidth.jl @@ -97,8 +97,9 @@ end end @testset "trace operation" begin - code = 
OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'a', 'd'], ['b', 'c', 'e', 'f']], Char['z']) - size_dict = Dict([c=>2 for c in ['a', 'b', 'c', 'd', 'e', 'f']]..., 'z'=>2) + code = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'a', 'd'], ['b', 'c', 'e', 'f']], Char['f']) + size_dict = Dict([c=>2 for c in ['a', 'b', 'c', 'd', 'e', 'f']]) + tensors = [rand([size_dict[j] for j in ixs]...) for ixs in getixsv(code)] optcode = optimize_code(code, size_dict, Treewidth(; alg=AMF())) - @test optcode == OMEinsumContractionOrders.NestedEinsum([OMEinsumContractionOrders.NestedEinsum([OMEinsumContractionOrders.NestedEinsum{Char}(2), OMEinsumContractionOrders.NestedEinsum{Char}(1)], OMEinsumContractionOrders.EinCode([['a', 'a', 'd'], ['a', 'b']], ['b'])), OMEinsumContractionOrders.NestedEinsum{Char}(3)], OMEinsumContractionOrders.EinCode([['b'], ['b', 'c', 'e', 'f']], ['z'])) -end \ No newline at end of file + @test decorate(code)(tensors...) ≈ decorate(optcode)(tensors...) +end