diff --git a/src/greedy.jl b/src/greedy.jl index d784133..fac568e 100644 --- a/src/greedy.jl +++ b/src/greedy.jl @@ -160,10 +160,45 @@ function analyze_contraction(incidence_list::IncidenceList{Int,ET}, vi::Int, vj: return LegInfo(leg1, leg2, leg12, leg01, leg02, leg012) end +function analyze_contraction_fast(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int, log2_edge_sizes::AbstractDict{ET, T}) where {ET, T} + ei = edges(incidence_list, vi) + ej = edges(incidence_list, vj) + + D1 = D2 = D12 = D01 = D02 = D012 = zero(T) + + # external legs + for leg in ei ∪ ej + dim = log2_edge_sizes[leg] + isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg)) + + if isext + if leg ∈ ei + if leg ∈ ej + D012 += dim + else + D01 += dim + end + else + D02 += dim + end + else + if leg ∈ ei + if leg ∈ ej + D12 += dim + else + D1 += dim + end + else + D2 += dim + end + end + end + + return D1, D2, D12, D01, D02, D012 +end + function greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj) - log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) # for 1.5, you need this patch because `init` kw is not allowed. - legs = analyze_contraction(incidence_list, vi, vj) - D1,D2,D12,D01,D02,D012 = log2dim.(getfield.(Ref(legs), 1:6)) + D1, D2, D12, D01, D02, D012 = analyze_contraction_fast(incidence_list, vi, vj, log2_edge_sizes) loss = exp2(D01+D02+D012) - α * (exp2(D01+D12+D012) + exp2(D02+D12+D012)) # out - in return loss end @@ -251,7 +286,26 @@ function optimize_greedy(ixs::AbstractVector{<:AbstractVector}, iy::AbstractVect parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2] end -function optimize_greedy(code::E, size_dict; α, temperature) where E <: NestedEinsum +function optimize_greedy_log2(code::EinCode{L}, size_dict::Dict{L}, size_dict_log2::Dict{L}; α, temperature) where {L} + ixs = getixsv(code) + iy = getiyv(code) + + if length(ixs) <= 2 + return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy)) + end + + incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)]); openedges=iy) + tree, _, _ = tree_greedy(incidence_list, size_dict_log2; α, temperature) + return parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2] +end + +function optimize_greedy(code::E, size_dict::AbstractDict{L}; α, temperature) where {E <: NestedEinsum, L} + size_dict_log2 = Dict{L, Float64}() + + for (lbl, dim) in size_dict + size_dict_log2[lbl] = log2(dim) + end + # construct first-child next-sibling representation of `code` queue = [code] child = [0] @@ -296,7 +350,7 @@ function optimize_greedy(code::E, size_dict; α, temperature) where E <: NestedE end if length(args) > 2 - code = replace_args(optimize_greedy(code.eins, size_dict; α, temperature), args) + code = replace_args(optimize_greedy_log2(code.eins, size_dict, size_dict_log2; α, temperature), args) else code = NestedEinsum(args, code.eins) end