diff --git a/docs/src/ref.md b/docs/src/ref.md index 3e90f33..6fb4939 100644 --- a/docs/src/ref.md +++ b/docs/src/ref.md @@ -3,7 +3,7 @@ ## Data structures and interfaces ```@autodocs Modules = [OMEinsumContractionOrders] -Pages = ["Core.jl"] +Pages = ["Core.jl", "utils.jl"] ``` ## Time and space complexity diff --git a/src/greedy.jl b/src/greedy.jl index fac568e..65be7c9 100644 --- a/src/greedy.jl +++ b/src/greedy.jl @@ -53,17 +53,16 @@ function tree_greedy(incidence_list::IncidenceList{Int, ET}, log2_edge_sizes; α end end -function contract_pair!(incidence_list, vi, vj, log2_edge_sizes) - log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) # for 1.5, you need this patch because `init` kw is not allowed. - # compute time complexity and output tensor - legsets = analyze_contraction(incidence_list, vi, vj) - D12,D01,D02,D012 = log2dim.(getfield.(Ref(legsets),3:6)) - tc = D12+D01+D02+D012 # dangling legs D1 and D2 do not contribute +function contract_pair!(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int, log2_edge_sizes) where {ET} + # Compute dimensions and edge lists in one pass + eout, eremove = ET[], ET[] + D1, D2, D12, D01, D02, D012 = compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove) + + tc = D12 + D01 + D02 + D012 # dangling legs D1 and D2 do not contribute + sc = D01 + D02 + D012 # space complexity is the output tensor size # einsum code - eout = legsets.l01 ∪ legsets.l02 ∪ legsets.l012 code = (edges(incidence_list, vi), edges(incidence_list, vj)) => eout - sc = log2dim(eout) # change incidence_list delete_vertex!(incidence_list, vj) @@ -71,7 +70,7 @@ function contract_pair!(incidence_list, vi, vj, log2_edge_sizes) for e in eout replace_vertex!(incidence_list, e, vj=>vi) end - remove_edges!(incidence_list, legsets.l1 ∪ legsets.l2 ∪ legsets.l12) + remove_edges!(incidence_list, eremove) return tc, sc, code end @@ -128,77 +127,69 @@ function find_best_cost!(temperature::TT, cost_values::PriorityQueue{PT}, cost_g end end -function analyze_contraction(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int) where {ET} +""" + compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove) -> (D1, D2, D12, D01, D02, D012) + +Compute the log2 dimensions and edge lists for contracting vertices `vi` and `vj`. +Returns a tuple of six Float64 dimension values: +- D1: edges only in vi and internal +- D2: edges only in vj and internal +- D12: edges in both vi and vj and internal +- D01: edges only in vi and external +- D02: edges only in vj and external +- D012: edges in both vi and vj and external +""" +function compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove) ei = edges(incidence_list, vi) ej = edges(incidence_list, vj) - leg012,leg12,leg1,leg2,leg01,leg02 = ET[], ET[], ET[], ET[], ET[], ET[] - # external legs - for leg in ei ∪ ej + + # Initialize dimension accumulators + D1, D2, D12, D01, D02, D012 = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 + + # Process edges from vi + for leg in ei isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg)) + in_ej = leg ∈ ej + leg_size = log2_edge_sizes[leg] + if isext - if leg ∈ ei - if leg ∈ ej - push!(leg012, leg) - else - push!(leg01, leg) - end + eout !== nothing && push!(eout, leg) + if in_ej + D012 += leg_size else - push!(leg02, leg) + D01 += leg_size end else - if leg ∈ ei - if leg ∈ ej - push!(leg12, leg) - else - push!(leg1, leg) - end + eremove !== nothing && push!(eremove, leg) + if in_ej + D12 += leg_size else - push!(leg2, leg) + D1 += leg_size end end end - return LegInfo(leg1, leg2, leg12, leg01, leg02, leg012) -end - -function analyze_contraction_fast(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int, log2_edge_sizes::AbstractDict{ET, T}) where {ET, T} - ei = edges(incidence_list, vi) - ej = edges(incidence_list, vj) - - D1 = D2 = D12 = D01 = D02 = D012 = zero(T) - - # external legs - for leg in ei ∪ ej - dim = log2_edge_sizes[leg] - isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg)) - - if isext - if leg ∈ ei - if leg ∈ ej - D012 += dim - else - D01 += dim - end - else - D02 += dim - end - else - if leg ∈ ei - if leg ∈ ej - D12 += dim - else - D1 += dim - end + + # Process edges from vj that are not in vi + for leg in ej + if leg ∉ ei + isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, vertices(incidence_list, leg)) + leg_size = log2_edge_sizes[leg] + + if isext + eout !== nothing && push!(eout, leg) + D02 += leg_size else - D2 += dim + eremove !== nothing && push!(eremove, leg) + D2 += leg_size end end end - - return D1, D2, D12, D01, D02, D012 + + return (D1, D2, D12, D01, D02, D012) end function greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj) - D1, D2, D12, D01, D02, D012 = analyze_contraction_fast(incidence_list, vi, vj, log2_edge_sizes) + D1, D2, D12, D01, D02, D012 = compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, nothing, nothing) loss = exp2(D01+D02+D012) - α * (exp2(D01+D12+D012) + exp2(D02+D12+D012)) # out - in return loss end @@ -264,48 +255,23 @@ end Greedy optimizing the contraction order and return a `NestedEinsum` object. Check the docstring of `tree_greedy` for detailed explaination of other input arguments. """ -function optimize_greedy(code::EinCode{L}, size_dict::Dict{L, T2}; α, temperature) where {L, T2} - optimize_greedy(getixsv(code), getiyv(code), size_dict; α, temperature) -end -function convert_label(ne::NestedEinsum, labelmap::Dict{T1,T2}) where {T1,T2} - isleaf(ne) && return NestedEinsum{T2}(ne.tensorindex) - eins = EinCode([getindex.(Ref(labelmap), ix) for ix in ne.eins.ixs], getindex.(Ref(labelmap), ne.eins.iy)) - NestedEinsum([convert_label(arg, labelmap) for arg in ne.args], eins) +function optimize_greedy(code::AbstractEinsum, size_dict::Dict{L, T2}; α, temperature) where {L, T2} + optimize_greedy_log2size(code, _log2_size_dict(size_dict); α, temperature) end -function optimize_greedy(ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L}; α, temperature) where {L} - if length(ixs) <= 2 - return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy)) - end - log2_edge_sizes = Dict{L,Float64}() - for (k, v) in size_dict - log2_edge_sizes[k] = log2(v) - end - incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)]); openedges=iy) - tree, _, _ = tree_greedy(incidence_list, log2_edge_sizes; α, temperature) - parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2] +function optimize_greedy_log2size(code::EinCode{L}, log2_size_dict::Dict{L}; α, temperature) where {L} + _optimize_greedy_log2size(getixsv(code), getiyv(code), log2_size_dict; α, temperature) end - -function optimize_greedy_log2(code::EinCode{L}, size_dict::Dict{L}, size_dict_log2::Dict{L}; α, temperature) where {L} - ixs = getixsv(code) - iy = getiyv(code) - +function _optimize_greedy_log2size(ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, log2_size_dict::Dict{L}; α, temperature) where {L} if length(ixs) <= 2 return NestedEinsum(NestedEinsum{L}.(1:length(ixs)), EinCode(ixs, iy)) end - incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)]); openedges=iy) - tree, _, _ = tree_greedy(incidence_list, size_dict_log2; α, temperature) - return parse_eincode!(incidence_list, tree, 1:length(ixs), size_dict)[2] + tree, _, _ = tree_greedy(incidence_list, log2_size_dict; α, temperature) + parse_eincode!(incidence_list, tree, 1:length(ixs), log2_size_dict)[2] end -function optimize_greedy(code::E, size_dict::AbstractDict{L}; α, temperature) where {E <: NestedEinsum, L} - size_dict_log2 = Dict{L, Float64}() - - for (lbl, dim) in size_dict - size_dict_log2[lbl] = log2(dim) - end - +function optimize_greedy_log2size(code::NestedEinsum{L}, log2_size_dict; α, temperature) where {L} # construct first-child next-sibling representation of `code` queue = [code] child = [0] @@ -343,14 +309,14 @@ function optimize_greedy(code::E, size_dict::AbstractDict{L}; α, temperature) w if isleaf(code) push!(queue, code) else - args = E[] + args = NestedEinsum{L}[] for _ in code.args push!(args, pop!(queue)) end if length(args) > 2 - code = replace_args(optimize_greedy_log2(code.eins, size_dict, size_dict_log2; α, temperature), args) + code = replace_args(optimize_greedy_log2size(code.eins, log2_size_dict; α, temperature), args) else code = NestedEinsum(args, code.eins) end diff --git a/src/incidencelist.jl b/src/incidencelist.jl index fd9d105..e49a5ef 100644 --- a/src/incidencelist.jl +++ b/src/incidencelist.jl @@ -18,7 +18,7 @@ function IncidenceList(v2e::Dict{VT,Vector{ET}}; openedges=ET[]) where {VT,ET} IncidenceList(v2e, e2v, openedges) end -Base.copy(il::IncidenceList) = IncidenceList(deepcopy(il.v2e), deepcopy(il.e2v), copy(il.openedges)) +Base.copy(il::IncidenceList) = IncidenceList(Dict([k=>copy(v) for (k,v) in il.v2e]), Dict([k=>copy(v) for (k,v) in il.e2v]), copy(il.openedges)) function neighbors(il::IncidenceList{VT}, v) where VT res = VT[] diff --git a/src/treewidth.jl b/src/treewidth.jl index d72b2a9..4a1c139 100644 --- a/src/treewidth.jl +++ b/src/treewidth.jl @@ -68,6 +68,7 @@ function optimize_treewidth(optimizer::Treewidth, code::AbstractEinsum, size_dic end function optimize_treewidth(optimizer::Treewidth, ixs::AbstractVector{<:AbstractVector}, iy::AbstractVector, size_dict::Dict{L, Int}; binary::Bool=true) where {L} + log2_size_dict = _log2_size_dict(size_dict) marker = zeros(Int, max(length(ixs) + 1, length(size_dict))) # construct incidence matrix `ve` @@ -86,7 +87,7 @@ function optimize_treewidth(optimizer::Treewidth, ixs::AbstractVector{<:Abstract if binary # binarize contraction tree - code = _optimize_code(code, size_dict, GreedyMethod()) + code = optimize_greedy_log2size(code, log2_size_dict; α = 0.0, temperature = 0.0) end return code diff --git a/src/utils.jl b/src/utils.jl index 08bf6c0..9ba6ee5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,3 +2,23 @@ function log2sumexp2(s) ms = maximum(s) return log2(sum(x->exp2(x - ms), s)) + ms end + +function _log2_size_dict(size_dict::Dict{L, T2}) where {L, T2} + log2_size_dict = Dict{L,Float64}() + for (k, v) in size_dict + log2_size_dict[k] = log2(v) + end + return log2_size_dict +end + +""" + convert_label(ne::NestedEinsum, labelmap::Dict{T1,T2}) where {T1,T2} + +Convert the labels of a `NestedEinsum` object to new labels. +`labelmap` is a dictionary that maps the old labels to the new labels. +""" +function convert_label(ne::NestedEinsum, labelmap::Dict{T1,T2}) where {T1,T2} + isleaf(ne) && return NestedEinsum{T2}(ne.tensorindex) + eins = EinCode([getindex.(Ref(labelmap), ix) for ix in ne.eins.ixs], getindex.(Ref(labelmap), ne.eins.iy)) + NestedEinsum([convert_label(arg, labelmap) for arg in ne.args], eins) +end \ No newline at end of file diff --git a/test/greedy.jl b/test/greedy.jl index 9b352d6..c270555 100644 --- a/test/greedy.jl +++ b/test/greedy.jl @@ -1,21 +1,11 @@ using OMEinsumContractionOrders -using OMEinsumContractionOrders: analyze_contraction, contract_pair!, evaluate_costs, contract_tree!, log2sumexp2, parse_tree -using OMEinsumContractionOrders: IncidenceList, analyze_contraction, LegInfo, tree_greedy, parse_eincode, optimize_greedy +using OMEinsumContractionOrders: contract_pair!, evaluate_costs, contract_tree!, log2sumexp2, parse_tree +using OMEinsumContractionOrders: IncidenceList, LegInfo, tree_greedy, parse_eincode, optimize_greedy +using OMEinsumContractionOrders: compute_contraction_dims using Graphs using Test, Random -@testset "analyze contraction" begin - incidence_list = IncidenceList(Dict(1 => [1, 2, 11, 15, 6], 2=>[1, 3, 4, 13, 6], 3=>[2, 3, 5, 6], 4=>[5], 5=>[4, 6]), openedges=[3, 6, 15]) - info = analyze_contraction(incidence_list, 1, 2) - @test Set(info.l1) == Set([11]) - @test Set(info.l2) == Set([13]) - @test Set(info.l12) == Set([1]) - @test Set(info.l01) == Set([2,15]) - @test Set(info.l02) == Set([3, 4]) - @test Set(info.l012) == Set([6]) -end - @testset "parse eincode" begin incidence_list = IncidenceList(Dict(1 => [1, 2], 2=>[1, 3, 4], 3=>[2, 3, 5, 6], 4=>[5], 5=>[4, 6])) tree = OMEinsumContractionOrders.ContractionTree(OMEinsumContractionOrders.ContractionTree(1, 2), OMEinsumContractionOrders.ContractionTree(3, 4)) @@ -155,5 +145,126 @@ end cc = contraction_complexity(optcode, size_dict) push!(sc_list, cc.sc) end - @test length(unique!(sc_list)) > 1 + @test minimum(sc_list) == 5 +end + +@testset "greedy_loss optimization" begin + function analyze_contraction(incidence_list::IncidenceList{Int,ET}, vi::Int, vj::Int) where {ET} + ei = OMEinsumContractionOrders.edges(incidence_list, vi) + ej = OMEinsumContractionOrders.edges(incidence_list, vj) + leg012,leg12,leg1,leg2,leg01,leg02 = ET[], ET[], ET[], ET[], ET[], ET[] + # external legs + for leg in ei ∪ ej + isext = leg ∈ incidence_list.openedges || !all(x->x==vi || x==vj, OMEinsumContractionOrders.vertices(incidence_list, leg)) + if isext + if leg ∈ ei + if leg ∈ ej + push!(leg012, leg) + else + push!(leg01, leg) + end + else + push!(leg02, leg) + end + else + if leg ∈ ei + if leg ∈ ej + push!(leg12, leg) + else + push!(leg1, leg) + end + else + push!(leg2, leg) + end + end + end + return LegInfo(leg1, leg2, leg12, leg01, leg02, leg012) + end + @testset "analyze contraction" begin + incidence_list = IncidenceList(Dict(1 => [1, 2, 11, 15, 6], 2=>[1, 3, 4, 13, 6], 3=>[2, 3, 5, 6], 4=>[5], 5=>[4, 6]), openedges=[3, 6, 15]) + info = analyze_contraction(incidence_list, 1, 2) + @test Set(info.l1) == Set([11]) + @test Set(info.l2) == Set([13]) + @test Set(info.l12) == Set([1]) + @test Set(info.l01) == Set([2,15]) + @test Set(info.l02) == Set([3, 4]) + @test Set(info.l012) == Set([6]) + end + + @testset "compute_contraction_dims" begin + # Test that compute_contraction_dims matches analyze_contraction + + # Create test incidence list + incidence_list = IncidenceList( + Dict(1 => [1, 2, 11, 15, 6], 2 => [1, 3, 4, 13, 6], 3 => [2, 3, 5, 6], 4 => [5], 5 => [4, 6]), + openedges=[3, 6, 15] + ) + log2_edge_sizes = Dict(i => Float64(i % 3 + 1) for i in [1,2,3,4,5,6,11,13,15]) + + # Test various vertex pairs + for (vi, vj) in [(1, 2), (2, 3), (1, 3)] + if vi in keys(incidence_list.v2e) && vj in keys(incidence_list.v2e) + # Get dimensions using the new function + D1, D2, D12, D01, D02, D012 = compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, nothing, nothing) + eout, eremove = Int[], Int[] + D1_, D2_, D12_, D01_, D02_, D012_ = compute_contraction_dims(incidence_list, log2_edge_sizes, vi, vj, eout, eremove) + @test D1_ == D1 + @test D2_ == D2 + @test D12_ == D12 + @test D01_ == D01 + @test D02_ == D02 + @test D012_ == D012 + + # Compare with analyze_contraction + legs = analyze_contraction(incidence_list, vi, vj) + log2dim(legs_list) = isempty(legs_list) ? 0.0 : sum(l->log2_edge_sizes[l], legs_list) + D1_ref = log2dim(legs.l1) + D2_ref = log2dim(legs.l2) + D12_ref = log2dim(legs.l12) + D01_ref = log2dim(legs.l01) + D02_ref = log2dim(legs.l02) + D012_ref = log2dim(legs.l012) + + @test D1 ≈ D1_ref + @test D2 ≈ D2_ref + @test D12 ≈ D12_ref + @test D01 ≈ D01_ref + @test D02 ≈ D02_ref + @test D012 ≈ D012_ref + + # Verify edge lists + eout_ref = legs.l01 ∪ legs.l02 ∪ legs.l012 + eremove_ref = legs.l1 ∪ legs.l2 ∪ legs.l12 + @test Set(eout) == Set(eout_ref) + @test Set(eremove) == Set(eremove_ref) + end + end + end + + # Original implementation using analyze_contraction (for comparison) + function greedy_loss_with_vectors(α, incidence_list, log2_edge_sizes, vi, vj) + log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) + legs = analyze_contraction(incidence_list, vi, vj) + D1, D2, D12, D01, D02, D012 = log2dim.(getfield.(Ref(legs), 1:6)) + return exp2(D01+D02+D012) - α * (exp2(D01+D12+D012) + exp2(D02+D12+D012)) + end + + # Create a simple tensor network + code = OMEinsumContractionOrders.EinCode([[1,2], [2,3], [3,4], [4,5], [5,1]], Int[]) + ixs = OMEinsumContractionOrders.getixsv(code) + iy = OMEinsumContractionOrders.getiyv(code) + size_dict = Dict([i=>2 for i in 1:5]) + log2_edge_sizes = Dict([i=>log2(size_dict[i]) for i in keys(size_dict)]) + incidence_list = IncidenceList(Dict([i=>ixs[i] for i=1:length(ixs)]); openedges=iy) + + # Compare optimized vs original implementation + for α in [0.0, 0.5, 1.0] + for vi in keys(incidence_list.v2e), vj in keys(incidence_list.v2e) + if vi < vj + loss_new = OMEinsumContractionOrders.greedy_loss(α, incidence_list, log2_edge_sizes, vi, vj) + loss_old = greedy_loss_with_vectors(α, incidence_list, log2_edge_sizes, vi, vj) + @test loss_new ≈ loss_old + end + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 25afeca..eab1baf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,10 @@ end include("Core.jl") end +@testset "utils" begin + include("utils.jl") +end + @testset "greedy" begin include("greedy.jl") end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..a8dbe23 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,34 @@ +using OMEinsumContractionOrders +using OMEinsumContractionOrders: log2sumexp2, _log2_size_dict, convert_label +using Test + +@testset "log2sumexp2" begin + @test log2sumexp2([1.0, 2.0, 3.0]) ≈ log2(exp2(1.0) + exp2(2.0) + exp2(3.0)) + @test log2sumexp2([10.0, 10.0]) ≈ 11.0 + @test log2sumexp2([0.0]) ≈ 0.0 +end + +@testset "_log2_size_dict" begin + size_dict = Dict('a' => 2, 'b' => 4, 'c' => 8) + log2_dict = _log2_size_dict(size_dict) + @test log2_dict['a'] ≈ 1.0 + @test log2_dict['b'] ≈ 2.0 + @test log2_dict['c'] ≈ 3.0 +end + +@testset "convert_label" begin + # Single leaf + ne = OMEinsumContractionOrders.NestedEinsum{Char}(1) + labelmap = Dict{Char, Int}() + result = convert_label(ne, labelmap) + @test result.tensorindex == 1 + + # Simple contraction + ne = OMEinsumContractionOrders.NestedEinsum([OMEinsumContractionOrders.NestedEinsum{Char}(1), OMEinsumContractionOrders.NestedEinsum{Char}(2)], + OMEinsumContractionOrders.EinCode([['a','b'], ['b','c']], ['a','c'])) + labelmap = Dict('a' => 1, 'b' => 2, 'c' => 3) + result = convert_label(ne, labelmap) + @test result.eins.ixs == [[1,2], [2,3]] + @test result.eins.iy == [1,3] +end +