Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LuxorTensorPlot = ["LuxorGraphPlot"]
[compat]
AbstractTrees = "0.3, 0.4"
Aqua = "0.8"
CliqueTrees = "1.5.0"
CliqueTrees = "1.12"
DataStructures = "0.18"
Documenter = "1.10.1"
Graphs = "1"
Expand Down
4 changes: 2 additions & 2 deletions src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ module OMEinsumContractionOrders
using JSON
using SparseArrays
using StatsBase
using Base: RefValue
using Base: RefValue, oneto
using Base.Threads
using AbstractTrees
using TreeWidthSolver
using TreeWidthSolver.Graphs
using DataStructures: PriorityQueue, enqueue!, dequeue!, peek, dequeue_pair!
import CliqueTrees
using CliqueTrees: cliquetree, residual, EliminationAlgorithm, MMW, BFS, MCS, LexBFS, RCMMD, RCMGL, MCSM, LexM, AMF, MF, MMD, MF, BT, SafeRules, KaHyParND, METISND, ND, BestWidth
using CliqueTrees: cliquetree, cliquetree!, separator, residual, EliminationAlgorithm, MMW, BFS, MCS, LexBFS, RCMMD, RCMGL, MCSM, LexM, AMF, MF, MMD, MF, BT, SafeRules, KaHyParND, METISND, ND, BestWidth, ConnectedComponents

# interfaces
export simplify_code, optimize_code, slice_code, optimize_permute, label_elimination_order, uniformsize, ScoreFunction
Expand Down
21 changes: 14 additions & 7 deletions src/hypernd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,33 +37,39 @@ The optimizer is implemented using the tree decomposition library
dis::D = KaHyParND()
algs::A = (MF(), AMF(), MMD())
level::Int = 6
width::Int = 120
imbalances::StepRange{Int, Int} = 130:1:130
width::Int = 50
scale::Int = 100
imbalances::StepRange{Int, Int} = 100:10:800
score::ScoreFunction = ScoreFunction()
end

function optimize_hyper_nd(optimizer::HyperND, code, size_dict)
function optimize_hyper_nd(optimizer::HyperND, code::AbstractEinsum, size_dict::AbstractDict; binary::Bool=true)
dis = optimizer.dis
algs = optimizer.algs
level = optimizer.level
width = optimizer.width
scale = optimizer.scale
imbalances = optimizer.imbalances
score = optimizer.score

minscore = typemax(Float64)
local mincode

for imbalance in imbalances
curalg = SafeRules(ND(BestWidth(algs), dis; level, width, imbalance))
curoptimizer = Treewidth(; alg=curalg)
curcode = _optimize_code(code, size_dict, curoptimizer)
curalg = SafeRules(ND(BestWidth(algs), dis; level, width, scale, imbalance))
curopt = Treewidth(; alg=curalg)
curcode = optimize_treewidth(curopt, code, size_dict; binary=false)
curtc, cursc, currw = __timespacereadwrite_complexity(curcode, size_dict)

if score(curtc, cursc, currw) < minscore
minscore, mincode = score(curtc, cursc, currw), curcode
end
end

if binary
mincode = _optimize_code(mincode, size_dict, GreedyMethod())
end

return mincode
end

Expand All @@ -77,7 +83,8 @@ function Base.show(io::IO, ::MIME"text/plain", optimizer::HyperND{D, A}) where {

println(io, " level: $(optimizer.level)")
println(io, " width: $(optimizer.width)")
println(io, " scale: $(optimizer.scale)")
println(io, " imbalances: $(optimizer.imbalances)")
println(io, " target: $(optimizer.target)")
println(io, " score: $(optimizer.score)")
return
end
273 changes: 157 additions & 116 deletions src/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,149 +64,190 @@ The `BT` algorithm is an exact solver for the treewidth problem that implemented
const ExactTreewidth = Treewidth{SafeRules{BT, MMW{3}, MF}}
ExactTreewidth() = Treewidth()

# calculates the exact treewidth of a graph using TreeWidthSolver.jl. It takes an incidence list representation of the graph (`incidence_list`) and a dictionary of logarithm base 2 edge sizes (`log2_edge_sizes`) as input.
# Return: a `ContractionTree` representing the contraction process.
#
# - `incidence_list`: An incidence list representation of the graph.
# - `log2_edge_sizes`: A dictionary of logarithm base 2 edge sizes.
# - `alg`: The algorithm to use for the treewidth calculation.
function treewidth_method(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes, alg) where {VT,ET}
indices = collect(keys(incidence_list.e2v))
tensors = collect(keys(incidence_list.v2e))
weights = [log2_edge_sizes[e] for e in indices]
line_graph = il2lg(incidence_list, indices)

scalars = [i for i in tensors if isempty(incidence_list.v2e[i])]
contraction_trees = Vector{Union{ContractionTree, VT}}()

# avoid the case that the line graph is not connected
for vertice_ids in connected_components(line_graph)
lg = induced_subgraph(line_graph, vertice_ids)[1]
lg_indices = indices[vertice_ids]
lg_weights = weights[vertice_ids]

# construct tree decomposition
perm, tree = cliquetree(lg_weights, lg; alg) # `tree` is a vector of cliques
permute!(lg_indices, perm) # `perm` is a permutation

# construct elimination ordering
eo = map(Base.Iterators.reverse(tree)) do clique
# the vertices in `res` can be eliminated at the same time
res = residual(clique) # `res` is a unit range
return @view lg_indices[res]
end
"""
optimize_treewidth(optimizer, eincode, size_dict)

lg_e2v = Dict{ET, Vector{VT}}()
lg_v2e = Dict{VT, Vector{ET}}()
Optimize the contraction order by solving for the exact treewidth of the line graph corresponding to the eincode, and return a `NestedEinsum` object.
Check the docstring of `treewidth_method` for a detailed explanation of the other input arguments.
"""
function optimize_treewidth(optimizer::Treewidth{EL}, code::AbstractEinsum, size_dict::Dict; binary::Bool=true) where {EL}
    # Split the einsum into its input index lists (`ixs`) and its output
    # indices (`iy`), then delegate to the vector-based method, forwarding
    # the `binary` keyword unchanged.
    ixs = getixsv(code)
    iy = getiyv(code)
    return optimize_treewidth(optimizer, ixs, iy, size_dict; binary=binary)
end

for es in eo, e in es
vs = lg_e2v[e] = incidence_list.e2v[e]
function optimize_treewidth(optimizer::Treewidth{EL}, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L,TI}; binary::Bool=true) where {L, TI, EL}
marker = zeros(Int, max(length(size_dict), length(ixs) + 1))

# construct incidence matrix `ve`
# indices
# [ ]
# tensors [ ve ]
# [ ]
# we only care about the sparsity pattern
le = Dict{L, Int}(); el = L[] # el ∘ le = id
weights = Float64[]
colptr = Int[1]
rowval = Int[]
nzval = Int[]

for (v, ix) in enumerate(ixs)
for l in ix
if haskey(le, l)
e = le[l]
else
push!(weights, log2(size_dict[l]))
push!(el, l)
e = le[l] = length(el)
end

for v in vs
if !haskey(lg_v2e, v)
lg_v2e[v] = ET[]
end
if marker[e] < v
marker[e] = v
push!(rowval, e)
push!(nzval, 1)
end
end

push!(lg_v2e[v], e)
end
push!(colptr, length(rowval) + 1)
end

# add a "virtual" tensor with indices `iy`
v = length(colptr)

for l in iy
if haskey(le, l)
e = le[l]
else
push!(weights, log2(size_dict[l]))
push!(el, l)
e = le[l] = length(el)
end

lg_incidence_list = IncidenceList(lg_v2e, lg_e2v, ET[])
contraction_tree = eo2ct(eo, lg_incidence_list, log2_edge_sizes)
push!(contraction_trees, contraction_tree)
if marker[e] < v
marker[e] = v
push!(rowval, e)
push!(nzval, 1)
end
end

# add the scalars back to the contraction tree
return reduce((x,y) -> ContractionTree(x, y), contraction_trees ∪ scalars)
end
push!(colptr, length(rowval) + 1)

# transform incidence list to line graph
function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) where {VT, ET}
m = length(el)
n = length(colptr) - 1

line_graph = SimpleGraph(length(indicies))

for (i, e) in enumerate(indicies)
for v in incidence_list.e2v[e]
for ej in incidence_list.v2e[v]
if e != ej add_edge!(line_graph, i, findfirst(==(ej), indicies)) end
ev = SparseMatrixCSC{Int, Int}(m, n, colptr, rowval, nzval)
ve = copy(transpose(ev))

# construct line graph `ee`
# indices
# [ ]
# indices [ ee ]
# [ ]
# we only care about the sparsity pattern
ee = ve' * ve

# compute a tree (forest) decomposition of `ee`
perm, tree = cliquetree(weights, ee; alg=ConnectedComponents(optimizer.alg))

# find the bag containing `iy`, call it `root`
root = length(tree)

for e in view(rowvals(ev), nzrange(ev, n))
marker[e] = -1
end

for (b, bag) in enumerate(tree)
root < length(tree) && break

for e in residual(bag)
root < length(tree) && break

if marker[perm[e]] == -1
root = b
end
end
end

return line_graph
end
# make `root` a root node of the tree decomposition
permute!(perm, cliquetree!(tree, root))

# transform elimination order to contraction tree
function eo2ct(elimination_order::Vector{<:AbstractVector{TL}}, incidence_list::IncidenceList{VT, ET}, log2_edge_sizes) where {TL, VT, ET}
eo = copy(elimination_order)
incidence_list = copy(incidence_list)
contraction_tree_nodes = Vector{Union{VT, ContractionTree}}(collect(keys(incidence_list.v2e)))
tensors_list = Dict{VT, Int}()
for (i, v) in enumerate(contraction_tree_nodes)
tensors_list[v] = i
end
# permute incidence matrix `ve`
permute!(el, perm)
permute!(ve, oneto(n), perm)

# dynamic programming
stack = NestedEinsum{L}[]

for (b, bag) in enumerate(tree)
sep = separator(bag)
res = residual(bag)
code = NestedEinsum(NestedEinsum{L}[], EinCode(Vector{L}[], L[]))

flag = contraction_tree_nodes[1]

while !isempty(eo)
eliminated_vertices = pop!(eo) # e is a vector of vertices, which are eliminated at the same time
vs = unique!(vcat([incidence_list.e2v[ei] for ei in eliminated_vertices if haskey(incidence_list.e2v, ei)]...)) # the tensors to be contracted, since they are connected to the eliminated vertices
if length(vs) >= 2
sub_list_indices = unique!(vcat([incidence_list.v2e[v] for v in vs]...)) # the vertices connected to the tensors to be contracted
sub_list_open_indices = setdiff(sub_list_indices, eliminated_vertices) # the vertices connected to the tensors to be contracted but not eliminated
vmap = Dict([i => incidence_list.v2e[v] for (i, v) in enumerate(vs)])
sub_list = IncidenceList(vmap; openedges=sub_list_open_indices) # the subgraph of the contracted tensors
sub_tree, scs, tcs = tree_greedy(sub_list, log2_edge_sizes; α=0.0, temperature=0.0) # optmize the subgraph with greedy method
sub_tree = expand_indices(sub_tree, Dict([i => v for (i, v) in enumerate(vs)]))
vi = contract_tree!(incidence_list, sub_tree, log2_edge_sizes, scs, tcs) # insert the contracted tensors back to the total graph
contraction_tree_nodes[tensors_list[vi]] = st2ct(sub_tree, tensors_list, contraction_tree_nodes)
flag = vi
for e in sep
push!(code.eins.iy, el[e])
end
end

return contraction_tree_nodes[tensors_list[flag]]
end
for e in res, v in view(rowvals(ve), nzrange(ve, e))
if marker[v] != -2
marker[v] = -2

function expand_indices(sub_tree::Union{ContractionTree, VT}, vmap::Dict{Int, VT}) where{VT}
if sub_tree isa ContractionTree
return ContractionTree(expand_indices(sub_tree.left, vmap), expand_indices(sub_tree.right, vmap))
else
return vmap[sub_tree]
if v == n
append!(code.eins.iy, iy)
else
push!(code.args, NestedEinsum{L}(v))
push!(code.eins.ixs, ixs[v])
end
end
end

for _ in childindices(tree, b)
child = pop!(stack)
push!(code.args, child)
push!(code.eins.ixs, child.eins.iy)
end

push!(stack, code)
end
end

function st2ct(sub_tree::Union{ContractionTree, VT}, tensors_list::Dict{VT, Int}, contraction_tree_nodes::Vector) where{VT}
if sub_tree isa ContractionTree
return ContractionTree(st2ct(sub_tree.left, tensors_list, contraction_tree_nodes), st2ct(sub_tree.right, tensors_list, contraction_tree_nodes))
# we now have an expression for each root
# of the tree decomposition
if isone(length(stack))
code = only(stack)
else
return contraction_tree_nodes[tensors_list[sub_tree]]
end
end
code = NestedEinsum(NestedEinsum{L}[], EinCode(Vector{L}[], L[]))
append!(code.eins.iy, iy)

"""
optimize_treewidth(optimizer, eincode, size_dict)
while !isempty(stack)
child = pop!(stack)
push!(code.args, child)
push!(code.eins.ixs, child.eins.iy)
end
end

Optimize the contraction order by solving for the exact treewidth of the line graph corresponding to the eincode, and return a `NestedEinsum` object.
Check the docstring of `treewidth_method` for a detailed explanation of the other input arguments.
"""
function optimize_treewidth(optimizer::Treewidth{EL}, code::AbstractEinsum, size_dict::Dict) where {EL}
    # Split the einsum into its input index lists (`ixs`) and its output
    # indices (`iy`), then delegate to the vector-based method.
    ixs = getixsv(code)
    iy = getiyv(code)
    return optimize_treewidth(optimizer, ixs, iy, size_dict)
end
function optimize_treewidth(optimizer::Treewidth{EL}, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L,TI}) where {L, TI, EL}
if length(ixs) <= 2
return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy))
# append scalars to the root
for (v, ix) in enumerate(ixs)
if isempty(ix)
push!(code.args, NestedEinsum{L}(v))
push!(code.eins.ixs, ix)
end
end
log2_edge_sizes = Dict{L,Float64}()
for (k, v) in size_dict
log2_edge_sizes[k] = log2(v)

if binary
code = _optimize_code(code, size_dict, GreedyMethod())
end
# complete all open edges as a clique, connected with a dummy tensor
incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)] ∪ [(length(ixs) + 1 => iy)]))

tree = treewidth_method(incidence_list, log2_edge_sizes, optimizer.alg)
return code
end

# remove the dummy tensor added for open edges
optcode = parse_eincode!(incidence_list, tree, 1:length(ixs) + 1, size_dict)[2]
# no longer used
#
# Build the line graph of an incidence list. The result is a `SimpleGraph` with
# one vertex per entry of `indicies`: vertex `i` stands for `indicies[i]`.
# Two vertices are joined whenever the corresponding entries are both listed
# under a common key of `incidence_list.v2e` (presumably: two indices that
# appear on the same tensor — confirm against `IncidenceList`).
#
# - `incidence_list`: provides the `e2v` and `v2e` lookups used below.
# - `indicies`: the entries to use as graph vertices, in vertex order.
#
# Returns the constructed line graph.
function il2lg(incidence_list::IncidenceList{VT, ET}, indicies::Vector{ET}) where {VT, ET}
line_graph = SimpleGraph(length(indicies))

# for every entry `e` (vertex `i`), visit each `v` listed in `e2v[e]`, and
# connect `e` to every other entry `ej` listed in `v2e[v]`
for (i, e) in enumerate(indicies)
for v in incidence_list.e2v[e]
for ej in incidence_list.v2e[v]
# the vertex id of `ej` is recovered by searching `indicies`
if e != ej add_edge!(line_graph, i, findfirst(==(ej), indicies)) end
end
end
end

# NOTE(review): `optcode` and `ixs` are not defined in this function; this
# line looks like a leftover from a different (removed) function — confirm
return pivot_tree(optcode, length(ixs) + 1)
return line_graph
end
Loading