Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions src/complexity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,34 @@ function __timespacereadwrite_complexity(ei::EinCode, size_dict)
_timespacereadwrite_complexity(getixsv(ei), getiyv(ei), log2_sizes)
end

function _timespacereadwrite_complexity(ei::NestedEinsum, log2_sizes::Dict{L,VT}) where {L,VT}
isleaf(ei) && return (VT(-Inf), VT(-Inf), VT(-Inf))
tcs = VT[]
scs = VT[]
rws = VT[]
for arg in ei.args
tc, sc, rw = _timespacereadwrite_complexity(arg, log2_sizes)
push!(tcs, tc)
push!(scs, sc)
push!(rws, rw)
function _timespacereadwrite_complexity(ei::NestedEinsum, log2_sizes::Dict{L, VT}) where {L, VT}
min = typemin(VT)

tcs = [min]
scs = [min]
rws = [min]
stack = [ei]

while !isempty(stack)
ei = pop!(stack)

if !isleaf(ei)
tc, sc, rw = _timespacereadwrite_complexity(
getixsv(ei.eins),
getiyv(ei.eins),
log2_sizes,
)

append!(stack, ei.args)
push!(tcs, tc)
push!(scs, sc)
push!(rws, rw)
end
end
tc2, sc2, rw2 = _timespacereadwrite_complexity(getixsv(ei.eins), getiyv(ei.eins), log2_sizes)
tc = log2sumexp2([tcs..., tc2])
sc = max(reduce(max, scs, init=zero(VT)), sc2)
rw = log2sumexp2([rws..., rw2])

tc = log2sumexp2(tcs)
sc = maximum(scs)
rw = log2sumexp2(rws)
return tc, sc, rw
end

Expand Down
64 changes: 55 additions & 9 deletions src/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,62 @@ function optimize_greedy(ixs::AbstractVector{<:AbstractVector}, iy::AbstractVect
tree, _, _ = tree_greedy(incidence_list, log2_edge_sizes; α, temperature)
parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2]
end
function optimize_greedy(code::NestedEinsum, size_dict; α, temperature)
isleaf(code) && return code
args = optimize_greedy.(code.args, Ref(size_dict); α, temperature)
if length(code.args) > 2
# generate coarse grained hypergraph.
nested = optimize_greedy(code.eins, size_dict; α, temperature)
replace_args(nested, args)
else
NestedEinsum(args, code.eins)

function optimize_greedy(code::E, size_dict; α, temperature) where E <: NestedEinsum
# construct first-child next-sibling representation of `code`
queue = [code]
child = [0]
brother = [0]

for (i, code) in enumerate(queue)
if !isleaf(code)
for arg in code.args
push!(queue, arg)
push!(brother, child[i])
push!(child, 0)
child[i] = length(queue)
end
end
end

# construct postordering of `code`
order = similar(queue)
stack = [1]

for i in eachindex(queue)
j = pop!(stack); k = child[j]

while !iszero(k)
push!(stack, j); child[j] = brother[k]; j = k; k = child[j]
end

order[i] = queue[j]
end

# optimize `code` using dynamic programming
empty!(queue)

for code in order
if isleaf(code)
push!(queue, code)
else
args = E[]

for _ in code.args
push!(args, pop!(queue))
end

if length(args) > 2
code = replace_args(optimize_greedy(code.eins, size_dict; α, temperature), args)
else
code = NestedEinsum(args, code.eins)
end

push!(queue, code)
end
end

return only(queue)
end

function replace_args(nested::NestedEinsum{LT}, trueargs) where LT
Expand Down
8 changes: 6 additions & 2 deletions src/hypernd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ The optimizer is implemented using the tree decomposition library
score::ScoreFunction = ScoreFunction()
end

function optimize_hyper_nd(optimizer::HyperND, code::AbstractEinsum, size_dict::AbstractDict; binary::Bool=true)
function optimize_hyper_nd(optimizer::HyperND, code::AbstractEinsum, size_dict::Dict; binary::Bool=true)
optimize_hyper_nd(optimizer, getixsv(code), getiyv(code), size_dict; binary)
end

function optimize_hyper_nd(optimizer::HyperND, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict; binary::Bool=true)
dis = optimizer.dis
algs = optimizer.algs
level = optimizer.level
Expand All @@ -63,7 +67,7 @@ function optimize_hyper_nd(optimizer::HyperND, code::AbstractEinsum, size_dict::
for imbalance in imbalances
curalg = SafeRules(ND(BestWidth(algs), dis; level, width, scale, imbalance))
curopt = Treewidth(; alg=curalg)
curcode = optimize_treewidth(curopt, code, size_dict; binary=false)
curcode = optimize_treewidth(curopt, ixs, iy, size_dict; binary=false)
curtc, cursc, currw = __timespacereadwrite_complexity(curcode, size_dict)

if score(curtc, cursc, currw) < minscore
Expand Down
Loading