Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/ref.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Data structures and interfaces
```@autodocs
Modules = [OMEinsumContractionOrders]
Pages = ["Core.jl"]
Pages = ["Core.jl", "utils.jl"]
```

## Time and space complexity
Expand Down
160 changes: 63 additions & 97 deletions src/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,24 @@ function tree_greedy(incidence_list::IncidenceList{Int, ET}, log2_edge_sizes; α
end
end

function contract_pair!(incidence_list, vi, vj, log2_edge_sizes)
log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) # for 1.5, you need this patch because `init` kw is not allowed.
# compute time complexity and output tensor
legsets = analyze_contraction(incidence_list, vi, vj)
D12,D01,D02,D012 = log2dim.(getfield.(Ref(legsets),3:6))
tc = D12+D01+D02+D012 # dangling legs D1 and D2 do not contribute
function contract_pair!(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int, log2_edge_sizes) where {ET}
# Compute dimensions and edge lists in one pass
eout, eremove = ET[], ET[]
D1, D2, D12, D01, D02, D012 = compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove)

tc = D12 + D01 + D02 + D012 # dangling legs D1 and D2 do not contribute
sc = D01 + D02 + D012 # space complexity is the output tensor size

# einsum code
eout = legsets.l01 ∪ legsets.l02 ∪ legsets.l012
code = (edges(incidence_list, vi), edges(incidence_list, vj)) => eout
sc = log2dim(eout)

# change incidence_list
delete_vertex!(incidence_list, vj)
change_edges!(incidence_list, vi, eout)
for e in eout
replace_vertex!(incidence_list, e, vj=>vi)
end
remove_edges!(incidence_list, legsets.l1 ∪ legsets.l2 ∪ legsets.l12)
remove_edges!(incidence_list, eremove)
return tc, sc, code
end

Expand Down Expand Up @@ -128,77 +127,69 @@ function find_best_cost!(temperature::TT, cost_values::PriorityQueue{PT}, cost_g
end
end

function analyze_contraction(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int) where {ET}
"""
compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove) -> (D1, D2, D12, D01, D02, D012)

Compute the log2 dimensions and edge lists for contracting vertices `vi` and `vj`.
Returns a tuple of six Float64 dimension values:
- D1: edges only in vi and internal
- D2: edges only in vj and internal
- D12: edges in both vi and vj and internal
- D01: edges only in vi and external
- D02: edges only in vj and external
- D012: edges in both vi and vj and external
"""
function compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove)
ei = edges(incidence_list, vi)
ej = edges(incidence_list, vj)
leg012,leg12,leg1,leg2,leg01,leg02 = ET[], ET[], ET[], ET[], ET[], ET[]
# external legs
for leg in ei ∪ ej

# Initialize dimension accumulators
D1, D2, D12, D01, D02, D012 = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

# Process edges from vi
for leg in ei
isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg))
in_ej = leg ∈ ej
leg_size = log2_edge_sizes[leg]

if isext
if leg ∈ ei
if leg ∈ ej
push!(leg012, leg)
else
push!(leg01, leg)
end
eout !== nothing && push!(eout, leg)
if in_ej
D012 += leg_size
else
push!(leg02, leg)
D01 += leg_size
end
else
if leg ∈ ei
if leg ∈ ej
push!(leg12, leg)
else
push!(leg1, leg)
end
eremove !== nothing && push!(eremove, leg)
if in_ej
D12 += leg_size
else
push!(leg2, leg)
D1 += leg_size
end
end
end
return LegInfo(leg1, leg2, leg12, leg01, leg02, leg012)
end

function analyze_contraction_fast(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int, log2_edge_sizes::AbstractDict{ET, T}) where {ET, T}
ei = edges(incidence_list, vi)
ej = edges(incidence_list, vj)

D1 = D2 = D12 = D01 = D02 = D012 = zero(T)

# external legs
for leg in ei ∪ ej
dim = log2_edge_sizes[leg]
isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg))

if isext
if leg ∈ ei
if leg ∈ ej
D012 += dim
else
D01 += dim
end
else
D02 += dim
end
else
if leg ∈ ei
if leg ∈ ej
D12 += dim
else
D1 += dim
end

# Process edges from vj that are not in vi
for leg in ej
if leg ∉ ei
isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg))
leg_size = log2_edge_sizes[leg]

if isext
eout !== nothing && push!(eout, leg)
D02 += leg_size
else
D2 += dim
eremove !== nothing && push!(eremove, leg)
D2 += leg_size
end
end
end

return D1, D2, D12, D01, D02, D012
return (D1, D2, D12, D01, D02, D012)
end

function greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj)
D1, D2, D12, D01, D02, D012 = analyze_contraction_fast(incidence_list, vi, vj, log2_edge_sizes)
D1, D2, D12, D01, D02, D012 = compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, nothing, nothing)
loss = exp2(D01+D02+D012) - α * (exp2(D01+D12+D012) + exp2(D02+D12+D012)) # out - in
return loss
end
Expand Down Expand Up @@ -264,48 +255,23 @@ end
Greedy optimizing the contraction order and return a `NestedEinsum` object.
Check the docstring of `tree_greedy` for detailed explaination of other input arguments.
"""
function optimize_greedy(code::EinCode{L}, size_dict::Dict{L, T2}; α, temperature) where {L, T2}
optimize_greedy(getixsv(code), getiyv(code), size_dict; α, temperature)
end
function convert_label(ne::NestedEinsum, labelmap::Dict{T1,T2}) where {T1,T2}
isleaf(ne) && return NestedEinsum{T2}(ne.tensorindex)
eins = EinCode([getindex.(Ref(labelmap), ix) for ix in ne.eins.ixs], getindex.(Ref(labelmap), ne.eins.iy))
NestedEinsum([convert_label(arg, labelmap) for arg in ne.args], eins)
function optimize_greedy(code::AbstractEinsum, size_dict::Dict{L, T2}; α, temperature) where {L, T2}
optimize_greedy_log2size(code, _log2_size_dict(size_dict); α, temperature)
end

function optimize_greedy(ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L}; α, temperature) where {L}
if length(ixs) <= 2
return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy))
end
log2_edge_sizes = Dict{L,Float64}()
for (k, v) in size_dict
log2_edge_sizes[k] = log2(v)
end
incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)]); openedges=iy)
tree, _, _ = tree_greedy(incidence_list, log2_edge_sizes; α, temperature)
parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2]
function optimize_greedy_log2size(code::EinCode{L}, log2_size_dict::Dict{L}; α, temperature) where {L}
_optimize_greedy_log2size(getixsv(code), getiyv(code), log2_size_dict; α, temperature)
end

function optimize_greedy_log2(code::EinCode{L}, size_dict::Dict{L}, size_dict_log2::Dict{L}; α, temperature) where {L}
ixs = getixsv(code)
iy = getiyv(code)

function _optimize_greedy_log2size(ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, log2_size_dict::Dict{L}; α, temperature) where {L}
if length(ixs) <= 2
return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy))
end

incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)]); openedges=iy)
tree, _, _ = tree_greedy(incidence_list, size_dict_log2; α, temperature)
return parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2]
tree, _, _ = tree_greedy(incidence_list, log2_size_dict; α, temperature)
parse_eincode!(incidence_list, tree, 1:length(ixs), log2_size_dict)[2]
end

function optimize_greedy(code::E, size_dict::AbstractDict{L}; α, temperature) where {E <: NestedEinsum, L}
size_dict_log2 = Dict{L, Float64}()

for (lbl, dim) in size_dict
size_dict_log2[lbl] = log2(dim)
end

function optimize_greedy_log2size(code::NestedEinsum{L}, log2_size_dict; α, temperature) where {L}
# construct first-child next-sibling representation of `code`
queue = [code]
child = [0]
Expand Down Expand Up @@ -343,14 +309,14 @@ function optimize_greedy(code::E, size_dict::AbstractDict{L}; α, temperature) w
if isleaf(code)
push!(queue, code)
else
args = E[]
args = NestedEinsum{L}[]

for _ in code.args
push!(args, pop!(queue))
end

if length(args) > 2
code = replace_args(optimize_greedy_log2(code.eins, size_dict, size_dict_log2; α, temperature), args)
code = replace_args(optimize_greedy_log2size(code.eins, log2_size_dict; α, temperature), args)
else
code = NestedEinsum(args, code.eins)
end
Expand Down
2 changes: 1 addition & 1 deletion src/incidencelist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function IncidenceList(v2e::Dict{VT,Vector{ET}}; openedges=ET[]) where {VT,ET}
IncidenceList(v2e, e2v, openedges)
end

Base.copy(il::IncidenceList) = IncidenceList(deepcopy(il.v2e), deepcopy(il.e2v), copy(il.openedges))
Base.copy(il::IncidenceList) = IncidenceList(Dict([k=>copy(v) for (k,v) in il.v2e]), Dict([k=>copy(v) for (k,v) in il.e2v]), copy(il.openedges))

function neighbors(il::IncidenceList{VT}, v) where VT
res = VT[]
Expand Down
3 changes: 2 additions & 1 deletion src/treewidth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ function optimize_treewidth(optimizer::Treewidth, code::AbstractEinsum, size_dic
end

function optimize_treewidth(optimizer::Treewidth, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L, Int}; binary::Bool=true) where {L}
log2_size_dict = _log2_size_dict(size_dict)
marker = zeros(Int, max(length(ixs) + 1, length(size_dict)))

# construct incidence matrix `ve`
Expand All @@ -86,7 +87,7 @@ function optimize_treewidth(optimizer::Treewidth, ixs::AbstractVector{<:Abstract

if binary
# binarize contraction tree
code = _optimize_code(code, size_dict, GreedyMethod())
code = optimize_greedy_log2size(code, log2_size_dict; α = 0.0, temperature = 0.0)
end

return code
Expand Down
20 changes: 20 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,23 @@ function log2sumexp2(s)
ms = maximum(s)
return log2(sum(x->exp2(x - ms), s)) + ms
end

function _log2_size_dict(size_dict::Dict{L, T2}) where {L, T2}
log2_size_dict = Dict{L,Float64}()
for (k, v) in size_dict
log2_size_dict[k] = log2(v)
end
return log2_size_dict
end

"""
convert_label(ne::NestedEinsum, labelmap::Dict{T1,T2}) where {T1,T2}

Convert the labels of a `NestedEinsum` object to new labels.
`labelmap` is a dictionary that maps the old labels to the new labels.
"""
function convert_label(ne::NestedEinsum, labelmap::Dict{T1,T2}) where {T1,T2}
isleaf(ne) && return NestedEinsum{T2}(ne.tensorindex)
eins = EinCode([getindex.(Ref(labelmap), ix) for ix in ne.eins.ixs], getindex.(Ref(labelmap), ne.eins.iy))
NestedEinsum([convert_label(arg, labelmap) for arg in ne.args], eins)
end
Loading
Loading