From 44c5eebd55a7c40a7a9c9a6ad4918b9327530cce Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Sun, 29 Dec 2024 16:55:32 +0100 Subject: [PATCH 01/11] Correct minor bugs and dependency from Quantics.jl --- Project.toml | 3 +- src/PartitionedMPSs.jl | 2 +- src/adaptivemul.jl | 6 +- src/contract.jl | 6 +- src/partitionedmps.jl | 32 +++++----- src/patching.jl | 48 +++++++------- src/subdomainmps.jl | 53 ++++++++-------- src/util.jl | 120 +++++++++++++++++++++++++++++++++++ test/partitionedmps_tests.jl | 4 +- test/patching_tests.jl | 11 ++-- test/projector_tests.jl | 6 -- test/subdomainmps_tests.jl | 20 ++++-- 12 files changed, 222 insertions(+), 89 deletions(-) diff --git a/Project.toml b/Project.toml index c52805f..8d805c0 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,6 @@ julia = "1.6" [extras] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Quantics = "87f76fb3-a40a-40c9-a63c-29fcfe7b7547" [targets] -test = ["Test", "Random", "Quantics"] +test = ["Test", "Random"] diff --git a/src/PartitionedMPSs.jl b/src/PartitionedMPSs.jl index d9299b7..76c90a5 100644 --- a/src/PartitionedMPSs.jl +++ b/src/PartitionedMPSs.jl @@ -4,7 +4,7 @@ import OrderedCollections: OrderedSet, OrderedDict using EllipsisNotation using LinearAlgebra: LinearAlgebra -import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds, uniqueinds +import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds, uniqueinds, hasplev import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds, findsites import ITensors.TagSets: hastag, hastags diff --git a/src/adaptivemul.jl b/src/adaptivemul.jl index 669d612..f77c596 100644 --- a/src/adaptivemul.jl +++ b/src/adaptivemul.jl @@ -74,9 +74,9 @@ function adaptivecontract( result_blocks = SubDomainMPS[] for (p, muls) in patches - prjmpss = [contract(m.a, m.b; alg, cutoff, maxdim, kwargs...) for m in muls] - #patches[p] = +(prjmpss...; alg="fit", cutoff, maxdim) - push!(result_blocks, +(prjmpss...; alg="fit", cutoff, maxdim)) + subdmps = [contract(m.a, m.b; alg, cutoff, maxdim, kwargs...) for m in muls] + #patches[p] = +(subdmps...; alg="fit", cutoff, maxdim) + push!(result_blocks, +(subdmps...; alg="fit", cutoff, maxdim)) end return PartitionedMPS(result_blocks) diff --git a/src/contract.jl b/src/contract.jl index eef6d02..9c819e1 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -72,12 +72,12 @@ function projcontract( _, external_sites = _projector_after_contract(M1, M2) if !_is_externalsites_compatible_with_projector(external_sites, proj) - error("The projector contains projection onto a site what is not a external sites.") + error("The projector contains projection onto a site that is not an external site.") end - t1 = time_ns() + # t1 = time_ns() r = contract(M1, M2; alg, cutoff, maxdim, kwargs...) - t2 = time_ns() + # t2 = time_ns() #println("contract: $((t2 - t1)*1e-9) s") return r end diff --git a/src/partitionedmps.jl b/src/partitionedmps.jl index 10b265f..c7835b5 100644 --- a/src/partitionedmps.jl +++ b/src/partitionedmps.jl @@ -6,11 +6,11 @@ struct PartitionedMPS data::OrderedDict{Projector,SubDomainMPS} function PartitionedMPS(data::AbstractVector{SubDomainMPS}) - sites_all = [siteinds(prjmps) for prjmps in data] + sites_all = [siteinds(subdmps) for subdmps in data] for n in 2:length(data) Set(sites_all[n]) == Set(sites_all[1]) || error("Sitedims mismatch") end - isdisjoint([prjmps.projector for prjmps in data]) || error("Projectors are overlapping") + isdisjoint([subdmps.projector for subdmps in data]) || error("Projectors are overlapping") dict_ = OrderedDict{Projector,SubDomainMPS}( data[i].projector => data[i] for i in 1:length(data) @@ -58,19 +58,19 @@ Base.length(obj::PartitionedMPS) = length(obj.data) """ Indexing for PartitionedMPS. This is deprecated and will be removed in the future. """ -function Base.getindex(bmps::PartitionedMPS, i::Integer)::SubDomainMPS - @warn "Indexing for PartitionedMPS is deprecated. Use getindex(bmps, p::Projector) instead." - return first(Iterators.drop(values(bmps.data), i - 1)) +function Base.getindex(partmps::PartitionedMPS, i::Integer)::SubDomainMPS + @warn "Indexing for PartitionedMPS is deprecated. Use getindex(partmps, p::Projector) instead." + return first(Iterators.drop(values(partmps.data), i - 1)) end Base.getindex(obj::PartitionedMPS, p::Projector) = obj.data[p] -function Base.iterate(bmps::PartitionedMPS, state) - return iterate(bmps.data, state) +function Base.iterate(partmps::PartitionedMPS, state) + return iterate(partmps.data, state) end -function Base.iterate(bmps::PartitionedMPS) - return iterate(bmps.data) +function Base.iterate(partmps::PartitionedMPS) + return iterate(partmps.data) end """ @@ -92,11 +92,13 @@ Rearrange the site indices of the PartitionedMPS according to the given order. If nessecary, tensors are fused or split to match the new order. """ function rearrange_siteinds(obj::PartitionedMPS, sites) - return PartitionedMPS([rearrange_siteinds(prjmps, sites) for prjmps in values(obj)]) + return PartitionedMPS([rearrange_siteinds(subdmps, sites) for subdmps in values(obj)]) end function prime(Ψ::PartitionedMPS, args...; kwargs...) - return PartitionedMPS([prime(prjmps, args...; kwargs...) for prjmps in values(Ψ.data)]) + return PartitionedMPS([ + prime(subdmps, args...; kwargs...) for subdmps in values(Ψ.data) + ]) end """ @@ -234,7 +236,7 @@ end function ITensorMPS.MPO( obj::PartitionedMPS; cutoff=default_cutoff(), maxdim=default_maxdim() )::MPO - return MPO(collect(MPS(obj; cutoff=cutoff, maxdim=maxdim, kwargs...))) + return MPO(collect(MPS(obj; cutoff=cutoff, maxdim=maxdim))) end """ @@ -242,13 +244,13 @@ Make the PartitionedMPS diagonal for a given site index `s` by introducing a dum """ function makesitediagonal(obj::PartitionedMPS, site) return PartitionedMPS([ - _makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj) + _makesitediagonal(subdmps, site; baseplev=baseplev) for subdmps in values(obj) ]) end function _makesitediagonal(obj::PartitionedMPS, site; baseplev=0) return PartitionedMPS([ - _makesitediagonal(prjmps, site; baseplev=baseplev) for prjmps in values(obj) + _makesitediagonal(subdmps, site; baseplev=baseplev) for subdmps in values(obj) ]) end @@ -257,7 +259,7 @@ Extract diagonal of the PartitionedMPS for `s`, `s'`, ... for a given site index where `s` must have a prime level of 0. """ function extractdiagonal(obj::PartitionedMPS, site) - return PartitionedMPS([extractdiagonal(prjmps, site) for prjmps in values(obj)]) + return PartitionedMPS([extractdiagonal(subdmps, site) for subdmps in values(obj)]) end function dist(a::PartitionedMPS, b::PartitionedMPS) diff --git a/src/patching.jl b/src/patching.jl index e89a08d..d0cc728 100644 --- a/src/patching.jl +++ b/src/patching.jl @@ -5,35 +5,35 @@ If the bond dimension of the result reaches `maxdim`, perform patching recursively to reduce the bond dimension. """ function _add_patching( - prjmpss::AbstractVector{SubDomainMPS}; + subdmpss::AbstractVector{SubDomainMPS}; cutoff=0.0, maxdim=typemax(Int), alg="fit", patchorder=Index[], )::Vector{SubDomainMPS} - if length(unique([prjmps.projector for prjmps in prjmpss])) != 1 + if length(unique([sudmps.projector for sudmps in subdmpss])) != 1 error("All SubDomainMPS objects must have the same projector.") end # First perform addition upto given maxdim # TODO: Early termination if the bond dimension reaches maxdim - sum_approx = _add(prjmpss...; alg, cutoff, maxdim) + sum_approx = _add(subdmpss...; alg, cutoff, maxdim) # If the bond dimension is less than maxdim, return the result maxbonddim(sum_approx) < maxdim && return [sum_approx] - @assert maxbonddim(sum_approx) == maxdim + # @assert maxbonddim(sum_approx) == maxdim - nextprjidx = _next_projindex(prjmpss[1].projector, patchorder) + nextprjidx = _next_projindex(subdmpss[1].projector, patchorder) - nextprjidx === nothing && return PartitionedMPS(sum_approx) + nextprjidx === nothing && return [sum_approx] blocks = SubDomainMPS[] for prjval in 1:ITensors.dim(nextprjidx) - prj_ = prjmpss[1].projector & Projector(nextprjidx => prjval) + prj_ = subdmpss[1].projector & Projector(nextprjidx => prjval) blocks = blocks ∪ _add_patching( - [project(prjmps, prj_) for prjmps in prjmpss]; + [project(sudmps, prj_) for sudmps in subdmpss]; cutoff, maxdim, alg, @@ -60,13 +60,15 @@ end Add multiple PartitionedMPS objects. """ function add_patching( - bmpss::AbstractVector{PartitionedMPS}; + partmps::AbstractVector{PartitionedMPS}; cutoff=0.0, maxdim=typemax(Int), alg="fit", patchorder=Index[], )::PartitionedMPS - result = _add_patching(union(values(x) for x in bmpss); cutoff, maxdim, alg, patchorder) + result = _add_patching( + union(values(x) for x in partmps); cutoff, maxdim, alg, patchorder + ) return PartitionedMPS(result) end @@ -77,29 +79,31 @@ Do patching recursively to reduce the bond dimension. If the bond dimension of a SubDomainMPS exceeds `maxdim`, perform patching. """ function adaptive_patching( - prjmps::SubDomainMPS, patchorder; cutoff=0.0, maxdim=typemax(Int) + subdmps::SubDomainMPS, patchorder; cutoff=0.0, maxdim=typemax(Int) )::Vector{SubDomainMPS} - if maxbonddim(prjmps) <= maxdim - return [prjmps] + if maxbonddim(subdmps) <= maxdim + return [subdmps] end # If the bond dimension exceeds maxdim, perform patching - refined_prjmpss = SubDomainMPS[] - nextprjidx = _next_projindex(prjmps.projector, patchorder) + refined_subdmpss = SubDomainMPS[] + nextprjidx = _next_projindex(subdmps.projector, patchorder) if nextprjidx === nothing - return [prjmps] + return [subdmps] end for prjval in 1:ITensors.dim(nextprjidx) - prj_ = prjmps.projector & Projector(nextprjidx => prjval) - prjmps_ = truncate(project(prjmps, prj_); cutoff, maxdim) - if maxbonddim(prjmps_) <= maxdim - push!(refined_prjmpss, prjmps_) + prj_ = subdmps.projector & Projector(nextprjidx => prjval) + subdmps_ = truncate(project(subdmps, prj_); cutoff, maxdim) + if maxbonddim(subdmps_) <= maxdim + push!(refined_subdmpss, subdmps_) else - append!(refined_prjmpss, adaptive_patching(prjmps_, patchorder; cutoff, maxdim)) + append!( + refined_subdmpss, adaptive_patching(subdmps_, patchorder; cutoff, maxdim) + ) end end - return refined_prjmpss + return refined_subdmpss end """ diff --git a/src/subdomainmps.jl b/src/subdomainmps.jl index 431c531..1611bd7 100644 --- a/src/subdomainmps.jl +++ b/src/subdomainmps.jl @@ -150,9 +150,11 @@ function _add(ψ::AbstractMPS...; alg="fit", cutoff=1e-15, maxdim=typemax(Int), return +(ITensors.Algorithm(alg), ψ...) elseif alg == "densitymatrix" if cutoff < 1e-15 - @warn "Cutoff is very small, it may suffer from numerical round errors. The densitymatrix algorithm squares the singular values of the reduce density matrix. Please consider increasing it or using fit algorithm." + @warn "Cutoff is very small, it may suffer from numerical round errors. + The densitymatrix algorithm squares the singular values of the reduce density matrix. + Please consider increasing it or using fit algorithm." end - return +(ITensors.Algorithm"densitymatrix"(), ψ...; cutoff, maxdim, kwargs...) + return +(ITensors.Algorithm("densitymatrix"), ψ...; cutoff, maxdim, kwargs...) elseif alg == "fit" function f(x, y) return ITensors.truncate( @@ -211,39 +213,39 @@ function LinearAlgebra.norm(M::SubDomainMPS) end function _makesitediagonal( - SubDomainMPS::SubDomainMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0 + obj::SubDomainMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0 ) where {IndsT} - M_ = deepcopy(MPO(collect(MPS(SubDomainMPS)))) + M_ = deepcopy(MPO(collect(MPS(obj)))) for site in sites target_site::Int = only(findsites(M_, site)) M_[target_site] = _asdiagonal(M_[target_site], site; baseplev=baseplev) end - return project(M_, SubDomainMPS.projector) + return project(M_, obj.projector) end -function _makesitediagonal(SubDomainMPS::SubDomainMPS, site::Index; baseplev=0) - return _makesitediagonal(SubDomainMPS, [site]; baseplev=baseplev) +function _makesitediagonal(obj::SubDomainMPS, site::Index; baseplev=0) + return _makesitediagonal(obj, [site]; baseplev=baseplev) end -function makesitediagonal(SubDomainMPS::SubDomainMPS, site::Index) - return _makesitediagonal(SubDomainMPS, site; baseplev=0) +function makesitediagonal(obj::SubDomainMPS, site::Index) + return _makesitediagonal(obj, site; baseplev=0) end -function makesitediagonal(SubDomainMPS::SubDomainMPS, sites::AbstractVector{Index}) - return _makesitediagonal(SubDomainMPS, sites; baseplev=0) +function makesitediagonal(obj::SubDomainMPS, sites::AbstractVector{Index}) + return _makesitediagonal(obj, sites; baseplev=0) end -function makesitediagonal(SubDomainMPS::SubDomainMPS, tag::String) - mps_diagonal = makesitediagonal(MPS(SubDomainMPS), tag) +function makesitediagonal(obj::SubDomainMPS, tag::String) + mps_diagonal = makesitediagonal(MPS(obj), tag) SubDomainMPS_diagonal = SubDomainMPS(mps_diagonal) target_sites = findallsiteinds_by_tag( - unique(ITensors.noprime.(Iterators.flatten(siteinds(SubDomainMPS)))); tag=tag + unique(ITensors.noprime.(Iterators.flatten(siteinds(obj)))); tag=tag ) - newproj = deepcopy(SubDomainMPS.projector) + newproj = deepcopy(obj.projector) for s in target_sites - if isprojectedat(SubDomainMPS.projector, s) + if isprojectedat(obj.projector, s) newproj[ITensors.prime(s)] = newproj[s] end end @@ -252,7 +254,10 @@ function makesitediagonal(SubDomainMPS::SubDomainMPS, tag::String) end # FIXME: may be type unstable -function _find_site_allplevs(tensor::ITensor, site::Index; maxplev=10) +# Gianluca: FIXED (?) +function _find_site_allplevs( + tensor::ITensor, site::Index{T}; maxplev=10 +)::Vector{Index{T}} where {T} ITensors.plev(site) == 0 || error("Site index must be unprimed.") return [ ITensors.prime(site, plev) for @@ -261,9 +266,9 @@ function _find_site_allplevs(tensor::ITensor, site::Index; maxplev=10) end function extractdiagonal( - SubDomainMPS::SubDomainMPS, sites::AbstractVector{Index{IndsT}} + obj::SubDomainMPS, sites::AbstractVector{Index{IndsT}} ) where {IndsT} - tensors = collect(SubDomainMPS.data) + tensors = collect(obj.data) for i in eachindex(tensors) for site in intersect(sites, ITensors.inds(tensors[i])) sitewithallplevs = _find_site_allplevs(tensors[i], site) @@ -275,7 +280,7 @@ function extractdiagonal( end end - projector = deepcopy(SubDomainMPS.projector) + projector = deepcopy(obj.projector) for site in sites if site' in keys(projector.data) delete!(projector.data, site') @@ -284,9 +289,7 @@ function extractdiagonal( return SubDomainMPS(MPS(tensors), projector) end -function extractdiagonal(SubDomainMPS::SubDomainMPS, tag::String)::SubDomainMPS - targetsites = findallsiteinds_by_tag( - unique(ITensors.noprime.(PartitionedMPSs._allsites(SubDomainMPS))); tag=tag - ) - return extractdiagonal(SubDomainMPS, targetsites) +function extractdiagonal(obj::SubDomainMPS, tag::String)::SubDomainMPS + targetsites = findallsiteinds_by_tag(unique(ITensors.noprime.(_allsites(obj))); tag=tag) + return extractdiagonal(obj, targetsites) end diff --git a/src/util.jl b/src/util.jl index e8c05ef..7afdc28 100644 --- a/src/util.jl +++ b/src/util.jl @@ -105,3 +105,123 @@ function rearrange_siteinds(M::AbstractMPS, sites::Vector{Vector{Index{T}}})::MP tensors[end] *= t return MPS(tensors) end + +# A valid tag should not contain "=". +_valid_tag(tag::String)::Bool = !occursin("=", tag) + +""" +Find sites with the given tag + +For tag = `x`, if `sites` contains an Index object with `x`, the function returns a vector containing only its positon. + +If not, the function seach for all Index objects with tags `x=1`, `x=2`, ..., and return their positions. + +If no Index object is found, an empty vector will be returned. +""" +function findallsites_by_tag( + sites::Vector{Index{T}}; tag::String="x", maxnsites::Int=1000 +)::Vector{Int} where {T} + _valid_tag(tag) || error("Invalid tag: $tag") + + # 1) Check if there is an Index with exactly `tag` + idx = findall(hastags(tag), sites) + if !isempty(idx) + if length(idx) > 1 + error("Found more than one site index with tag $(tag)!") + end + return idx + end + + # 2) If not found, search for tag=1, tag=2, ... + result = Int[] + for n in 1:maxnsites + tag_ = tag * "=$n" + idx = findall(hastags(tag_), sites) + if length(idx) == 0 + break + elseif length(idx) > 1 + error("Found more than one site indices with $(tag_)!") + end + push!(result, idx[1]) + end + return result +end + +function findallsiteinds_by_tag( + sites::AbstractVector{Index{T}}; tag::String="x", maxnsites::Int=1000 +) where {T} + _valid_tag(tag) || error("Invalid tag: $tag") + positions = findallsites_by_tag(sites; tag=tag, maxnsites=maxnsites) + return [sites[p] for p in positions] +end + +function findallsites_by_tag( + sites::Vector{Vector{Index{T}}}; tag::String="x", maxnsites::Int=1000 +)::Vector{NTuple{2,Int}} where {T} + _valid_tag(tag) || error("Invalid tag: $tag") + + sites_dict = Dict{Index{T},NTuple{2,Int}}() + for i in 1:length(sites) + for j in 1:length(sites[i]) + sites_dict[sites[i][j]] = (i, j) + end + end + + sitesflatten = collect(Iterators.flatten(sites)) + + idx_exact = findall(i -> hastags(i, tag) && hasplev(i, 0), sitesflatten) + if !isempty(idx_exact) + if length(idx_exact) > 1 + error("Found more than one site index with tag '$tag'!") + end + # Return a single position + return [sites_dict[sitesflatten[only(idx_exact)]]] + end + + result = NTuple{2,Int}[] + for n in 1:maxnsites + tag_ = tag * "=$n" + idx = findall(i -> hastags(i, tag_) && hasplev(i, 0), sitesflatten) + if length(idx) == 0 + break + elseif length(idx) > 1 + error("Found more than one site indices with $(tag_)!") + end + + push!(result, sites_dict[sitesflatten[only(idx)]]) + end + return result +end + +function findallsiteinds_by_tag( + sites::Vector{Vector{Index{T}}}; tag::String="x", maxnsites::Int=1000 +)::Vector{Index{T}} where {T} + _valid_tag(tag) || error("Invalid tag: $tag") + positions = findallsites_by_tag(sites; tag=tag, maxnsites=maxnsites) + return [sites[i][j] for (i, j) in positions] +end + +function makesitediagonal(M::AbstractMPS, tag::String)::MPS + M_ = deepcopy(MPO(collect(M))) + sites = siteinds(M_) + + target_positions = findallsites_by_tag(siteinds(M_); tag=tag) + + for t in eachindex(target_positions) + i, j = target_positions[t] + M_[i] = _asdiagonal(M_[i], sites[i][j]) + end + + return MPS(collect(M_)) +end + +function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number} + dim(site) == dim(site2) || error("Dimension mismatch") + restinds = uniqueinds(inds(t), site, site2) + newdata = zeros(eltype(t), dim.(restinds)..., dim(site)) + olddata = Array(t, restinds..., site, site2) + for i in 1:dim(site) + newdata[.., i] = olddata[.., i, i] + end + return ITensor(newdata, restinds..., site) +end diff --git a/test/partitionedmps_tests.jl b/test/partitionedmps_tests.jl index cdd1deb..d60bb3e 100644 --- a/test/partitionedmps_tests.jl +++ b/test/partitionedmps_tests.jl @@ -30,8 +30,8 @@ import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, Parti @test length([(k, v) for (k, v) in PartitionedMPS(prjΨ1)]) == 1 Ψreconst = PartitionedMPS(prjΨ1) + PartitionedMPS(prjΨ2) - @test Ψreconst[1] ≈ prjΨ1 - @test Ψreconst[2] ≈ prjΨ2 + @test Ψreconst[Projector(sitesx[1] => 1)] ≈ prjΨ1 + @test Ψreconst[Projector(sitesx[1] => 2)] ≈ prjΨ2 @test MPS(Ψreconst) ≈ Ψ @test ITensors.norm(Ψreconst) ≈ ITensors.norm(MPS(Ψreconst)) diff --git a/test/patching_tests.jl b/test/patching_tests.jl index d6b6ba8..d8f5731 100644 --- a/test/patching_tests.jl +++ b/test/patching_tests.jl @@ -15,14 +15,15 @@ using Random sites = collect(collect.(zip(sitesx, sitesy))) - prjmps = SubDomainMPS(_random_mpo(sites; linkdims=20)) + subdmps = SubDomainMPS(_random_mpo(sites; linkdims=20)) sites_ = collect(Iterators.flatten(sites)) - bmps = PartitionedMPS(adaptive_patching(prjmps, sites_; maxdim=10, cutoff=1e-25)) + partmps = PartitionedMPS( + adaptive_patching(subdmps, sites_; maxdim=10, cutoff=1e-25) + ) - @test length(values((bmps))) > 1 + @test length(values((partmps))) > 1 - @test MPS(bmps) ≈ MPS(prjmps) rtol = 1e-12 - #MPS(bmps) ≈ MPS(prjmps) + @test MPS(partmps) ≈ MPS(subdmps) rtol = 1e-12 end end diff --git a/test/projector_tests.jl b/test/projector_tests.jl index 459018b..7712014 100644 --- a/test/projector_tests.jl +++ b/test/projector_tests.jl @@ -52,12 +52,6 @@ import PartitionedMPSs: Projector, hasoverlap @test p1 & p2 == Projector(Dict(inds[1] => 2, inds[2] => 1)) end - let - p1 = Projector(Dict(inds[2] => 1)) - p2 = Projector(Dict(inds[1] => 2)) - @test p1 & p2 == Projector(Dict(inds[1] => 2, inds[2] => 1)) - end - let p1 = Projector(Dict(inds[2] => 1, inds[3] => 1)) p2 = Projector(Dict(inds[1] => 2, inds[3] => 1)) diff --git a/test/subdomainmps_tests.jl b/test/subdomainmps_tests.jl index 3fa1cc0..6e327a9 100644 --- a/test/subdomainmps_tests.jl +++ b/test/subdomainmps_tests.jl @@ -5,7 +5,13 @@ using ITensors using Random import PartitionedMPSs: - PartitionedMPSs, Projector, project, SubDomainMPS, rearrange_siteinds + PartitionedMPSs, + Projector, + project, + SubDomainMPS, + rearrange_siteinds, + makesitediagonal, + extractdiagonal @testset "subdomainmps.jl" begin @testset "SubDomainMPS" begin @@ -27,7 +33,6 @@ import PartitionedMPSs: @test Ψreconst ≈ Ψ end - #== @testset "rearrange_siteinds" begin N = 3 sitesx = [Index(2, "x=$n") for n in 1:N] @@ -75,6 +80,9 @@ import PartitionedMPSs: @test extractdiagonal(prjΨ1_diagonalz, "y") ≈ prjΨ1 + diag_ok = true + offdiag_ok = true + for indval in eachindval(sites_diagonalz...) ind = first.(indval) val = last.(indval) @@ -94,11 +102,13 @@ import PartitionedMPSs: if isdiagonalelement nondiaginds = unique(noprime(i) => v for (i, v) in indval) - @test psi_diag[indval...] == psi[nondiaginds...] + diag_ok = diag_ok && (psi_diag[indval...] == psi[nondiaginds...]) else - @test iszero(psi_diag[indval...]) + offdiag_ok = offdiag_ok && iszero(psi_diag[indval...]) end end + + @test diag_ok + @test offdiag_ok end - ==# end From a58e88efeadfedc4ef7cb3a931334c5f940dc194 Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Sun, 2 Mar 2025 04:45:49 +0100 Subject: [PATCH 02/11] Implement elementwise product, bug fixes and reformat --- src/PartitionedMPSs.jl | 3 +- src/automul.jl | 115 +++++++++++++++++++++++++++++++++++ src/bak/conversion.jl | 58 ++++++++++++++++++ src/contract.jl | 16 ++--- src/partitionedmps.jl | 18 ++++-- src/subdomainmps.jl | 54 ++++++++-------- src/util.jl | 87 ++++++++++++++++++++++---- test/_util.jl | 4 ++ test/automul_tests.jl | 112 ++++++++++++++++++++++++++++++++++ test/bak/conversion_tests.jl | 2 +- test/contract_tests.jl | 3 +- test/runtests.jl | 5 +- 12 files changed, 421 insertions(+), 56 deletions(-) create mode 100644 src/automul.jl create mode 100644 src/bak/conversion.jl create mode 100644 test/automul_tests.jl diff --git a/src/PartitionedMPSs.jl b/src/PartitionedMPSs.jl index 76c90a5..8b9ce86 100644 --- a/src/PartitionedMPSs.jl +++ b/src/PartitionedMPSs.jl @@ -5,7 +5,7 @@ using EllipsisNotation using LinearAlgebra: LinearAlgebra import ITensors: ITensors, Index, ITensor, dim, inds, qr, commoninds, uniqueinds, hasplev -import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds, findsites +import ITensorMPS: ITensorMPS, AbstractMPS, MPS, MPO, siteinds, findsites, findsite import ITensors.TagSets: hastag, hastags import FastMPOContractions as FMPOC @@ -20,5 +20,6 @@ include("partitionedmps.jl") include("patching.jl") include("contract.jl") include("adaptivemul.jl") +include("automul.jl") end diff --git a/src/automul.jl b/src/automul.jl new file mode 100644 index 0000000..4c96bc3 --- /dev/null +++ b/src/automul.jl @@ -0,0 +1,115 @@ +@doc raw""" + function elemmul( + M1::PartitionedMPS, + M2::PartitionedMPS + ) + +Performs elementwise multiplication between partitioned MPSs. Element-wise product is defined +as: + +```math + (fg) (\xi) = f(\xi)g(\xi) = \sum_{\xi'} f(\xi, \xi') g(\xi') +``` + +where ``f(\xi, \xi') = f(\xi) \delta_{\xi, \xi'}``. +""" +function elemmul( + M1::PartitionedMPS, + M2::PartitionedMPS; + alg="zipup", + maxdim=typemax(Int), + cutoff=1e-25, + kwargs..., +) + all(length.(siteinds(M1)) .== 1) || error("M1 should have only 1 site index per site") + all(length.(siteinds(M2)) .== 1) || error("M2 should have only 1 site index per site") + + only.(siteinds(M1)) == only.(siteinds(M2)) || + error("Sites for element wise multiplication should be identical") + sites_ewmul = only.(siteinds(M1)) + + M1 = makesitediagonal(M1, sites_ewmul; baseplev=1) + M2 = makesitediagonal(M2, sites_ewmul; baseplev=0) + + M = contract(M1, M2; alg=alg, kwargs...) + + M = extractdiagonal(M, sites_ewmul) + + return truncate(M; cutoff=cutoff, maxdim=maxdim) +end + +@doc raw""" + function automul( + M1::PartitionedMPS, + M2::PartitionedMPS; + tag_row::String="", + tag_shared::String="", + tag_col::String="", + ... +) + +Performs automatic multiplication between partitioned MPSs. Automatic multiplication is defined +as: + +```math + (fg) (\sigma_{row}, \sigma_{col}; \xi) = \sum_{\sigma_{shared}} + f(\sigma_{row}, \sigma_{shared}; \xi) g(\sigma_{shared}, \sigma_{col} ; \xi). +``` + +By default, only element-wise product on sites ``\xi`` will be performed. See also: [`elemmul`](@ref). +""" +function automul( + M1::PartitionedMPS, + M2::PartitionedMPS; + tag_row::String="", + tag_shared::String="", + tag_col::String="", + alg="zipup", + maxdim=typemax(Int), + cutoff=1e-25, + kwargs..., +) + all(length.(siteinds(M1)) .== 1) || error("M1 should have only 1 site index per site") + all(length.(siteinds(M2)) .== 1) || error("M2 should have only 1 site index per site") + + sites_row = _findallsiteinds_by_tag(M1; tag=tag_row) + sites_shared = _findallsiteinds_by_tag(M1; tag=tag_shared) + sites_col = _findallsiteinds_by_tag(M2; tag=tag_col) + sites_matmul = Set(Iterators.flatten([sites_row, sites_shared, sites_col])) + + sites1 = only.(siteinds(M1)) + sites1_ewmul = setdiff(only.(siteinds(M1)), sites_matmul) + sites2_ewmul = setdiff(only.(siteinds(M2)), sites_matmul) + sites2_ewmul == sites1_ewmul || error("Invalid sites for elementwise multiplication") + + M1 = makesitediagonal(M1, sites1_ewmul; baseplev=1) + M2 = makesitediagonal(M2, sites2_ewmul; baseplev=0) + + sites_M1_diag = [collect(x) for x in siteinds(M1)] + sites_M2_diag = [collect(x) for x in siteinds(M2)] + + M1 = rearrange_siteinds(M1, combinesites(sites_M1_diag, sites_row, sites_shared)) + + M2 = rearrange_siteinds(M2, combinesites(sites_M2_diag, sites_shared, sites_col)) + + M = contract(M1, M2; alg=alg, kwargs...) + + M = extractdiagonal(M, sites1_ewmul) + + ressites = Vector{eltype(siteinds(M1)[1])}[] + for s in siteinds(M) + s_ = unique(ITensors.noprime.(s)) + if length(s_) == 1 + push!(ressites, s_) + else + if s_[1] ∈ sites1 + push!(ressites, [s_[1]]) + push!(ressites, [s_[2]]) + else + push!(ressites, [s_[2]]) + push!(ressites, [s_[1]]) + end + end + end + return truncate(rearrange_siteinds(M, ressites); cutoff=cutoff, maxdim=maxdim) +end diff --git a/src/bak/conversion.jl b/src/bak/conversion.jl new file mode 100644 index 0000000..10871f5 --- /dev/null +++ b/src/bak/conversion.jl @@ -0,0 +1,58 @@ +""" +Conversion from a `TCIAlgorithms.ProjTensorTrain` to a `SubDomainMPS`. +""" +function SubDomainMPS(projtt::TCIA.ProjTensorTrain{T}, sites)::SubDomainMPS where {T} + links = [Index(ld, "Link,l=$l") for (l, ld) in enumerate(TCI.linkdims(projtt.data))] + + tensors = ITensor[] + sitedims = [collect(dim.(s)) for s in sites] + linkdims = dim.(links) + + push!( + tensors, + ITensor( + reshape(projtt.data[1], 1, prod(sitedims[1]), linkdims[1]), + sites[1]..., + links[1], + ), + ) + + for n in 2:(length(projtt.data) - 1) + push!( + tensors, + ITensor( + reshape(projtt.data[n], linkdims[n - 1], prod(sitedims[n]), linkdims[n]), + links[n - 1], + sites[n]..., + links[n], + ), + ) + end + + push!( + tensors, + ITensor( + reshape(projtt.data[end], linkdims[end], prod(sitedims[end])), + links[end], + sites[end]..., + ), + ) + + proj = Dict{Index,Int}() + for i in eachindex(projtt.projector.data) + for j in eachindex(projtt.projector.data[i]) + if projtt.projector.data[i][j] > 0 + proj[sites[i][j]] = projtt.projector.data[i][j] + end + end + end + + return SubDomainMPS(MPS(tensors), Projector(proj)) +end + +""" +Conversion from a `TCIAlgorithms.Proj` to a `PartitionedMPS`. +""" +function PartitionedMPS(obj::TCIA.ProjTTContainer{T}, sites)::PartitionedMPS where {T} + return PartitionedMPS([SubDomainMPS(x, sites) for x in obj.data]) +end diff --git a/src/contract.jl b/src/contract.jl index 9c819e1..35e5394 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -56,7 +56,7 @@ function projcontract( M1::SubDomainMPS, M2::SubDomainMPS, proj::Projector; - alg="fit", + alg="zipup", cutoff=default_cutoff(), maxdim=default_maxdim(), verbosity=0, @@ -90,7 +90,7 @@ function projcontract( M1::AbstractVector{SubDomainMPS}, M2::AbstractVector{SubDomainMPS}, proj::Projector; - alg="fit", + alg="zipup", alg_sum="fit", cutoff=default_cutoff(), maxdim=default_maxdim(), @@ -139,7 +139,7 @@ At each site, the objects must share at least one site index. function contract( M1::PartitionedMPS, M2::PartitionedMPS; - alg="fit", + alg="zipup", cutoff=default_cutoff(), maxdim=default_maxdim(), patchorder=Index[], @@ -158,7 +158,7 @@ function contract!( M::PartitionedMPS, M1::PartitionedMPS, M2::PartitionedMPS; - alg="fit", + alg="zipup", cutoff=default_cutoff(), maxdim=default_maxdim(), patchorder=Index[], @@ -179,10 +179,10 @@ function contract!( if haskey(M.data, b) && !overwrite continue end - res::Vector{SubDomainMPS} = projcontract( - M1_, M2_, b; alg, cutoff, maxdim, patchorder, kwargs... - ) - append!(M, res) + res = projcontract(M1_, M2_, b; alg, cutoff, maxdim, patchorder, kwargs...) + if res !== nothing + append!(M, res) + end end return M end diff --git a/src/partitionedmps.jl b/src/partitionedmps.jl index c7835b5..da653bb 100644 --- a/src/partitionedmps.jl +++ b/src/partitionedmps.jl @@ -242,15 +242,17 @@ end """ Make the PartitionedMPS diagonal for a given site index `s` by introducing a dummy index `s'`. """ -function makesitediagonal(obj::PartitionedMPS, site) +function makesitediagonal( + obj::PartitionedMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0 +) where {IndsT} return PartitionedMPS([ - _makesitediagonal(subdmps, site; baseplev=baseplev) for subdmps in values(obj) + makesitediagonal(subdmps, sites; baseplev=baseplev) for subdmps in values(obj) ]) end -function _makesitediagonal(obj::PartitionedMPS, site; baseplev=0) +function makesitediagonal(obj::PartitionedMPS, site::Index{IndsT}; baseplev=0) where {IndsT} return PartitionedMPS([ - _makesitediagonal(subdmps, site; baseplev=baseplev) for subdmps in values(obj) + makesitediagonal(subdmps, site; baseplev=baseplev) for subdmps in values(obj) ]) end @@ -258,10 +260,14 @@ end Extract diagonal of the PartitionedMPS for `s`, `s'`, ... for a given site index `s`, where `s` must have a prime level of 0. """ -function extractdiagonal(obj::PartitionedMPS, site) - return PartitionedMPS([extractdiagonal(subdmps, site) for subdmps in values(obj)]) +function extractdiagonal(obj::PartitionedMPS, sites) + return PartitionedMPS([extractdiagonal(subdmps, sites) for subdmps in values(obj)]) end function dist(a::PartitionedMPS, b::PartitionedMPS) return sqrt(sum(ITensorMPS.dist(MPS(a[k]), MPS(b[k]))^2 for k in keys(a))) end + +function _findallsiteinds_by_tag(partmps::PartitionedMPS; tag=tag) + return findallsiteinds_by_tag(only.(siteinds(partmps)); tag=tag) +end diff --git a/src/subdomainmps.jl b/src/subdomainmps.jl index 1611bd7..82c72e8 100644 --- a/src/subdomainmps.jl +++ b/src/subdomainmps.jl @@ -220,19 +220,29 @@ function _makesitediagonal( target_site::Int = only(findsites(M_, site)) M_[target_site] = _asdiagonal(M_[target_site], site; baseplev=baseplev) end - return project(M_, obj.projector) -end -function _makesitediagonal(obj::SubDomainMPS, site::Index; baseplev=0) - return _makesitediagonal(obj, [site]; baseplev=baseplev) + newproj = deepcopy(obj.projector) + for s in sites + if isprojectedat(obj.projector, s) + newproj.data[ITensors.prime(s, baseplev + 1)] = newproj.data[s] + if baseplev != 0 + newproj.data[ITensors.prime(s, baseplev)] = newproj.data[s] + delete!(newproj.data, s) + end + end + end + + return project(M_, newproj) end -function makesitediagonal(obj::SubDomainMPS, site::Index) - return _makesitediagonal(obj, site; baseplev=0) +function makesitediagonal(obj::SubDomainMPS, site::Index{IndsT}; baseplev=0) where {IndsT} + return _makesitediagonal(obj, [site]; baseplev=baseplev) end -function makesitediagonal(obj::SubDomainMPS, sites::AbstractVector{Index}) - return _makesitediagonal(obj, sites; baseplev=0) +function makesitediagonal( + obj::SubDomainMPS, sites::AbstractVector{Index{IndsT}}; baseplev=0 +) where {IndsT} + return _makesitediagonal(obj, sites; baseplev=baseplev) end function makesitediagonal(obj::SubDomainMPS, tag::String) @@ -253,18 +263,6 @@ function makesitediagonal(obj::SubDomainMPS, tag::String) return project(SubDomainMPS_diagonal, newproj) end -# FIXME: may be type unstable -# Gianluca: FIXED (?) -function _find_site_allplevs( - tensor::ITensor, site::Index{T}; maxplev=10 -)::Vector{Index{T}} where {T} - ITensors.plev(site) == 0 || error("Site index must be unprimed.") - return [ - ITensors.prime(site, plev) for - plev in 0:maxplev if ITensors.prime(site, plev) ∈ ITensors.inds(tensor) - ] -end - function extractdiagonal( obj::SubDomainMPS, sites::AbstractVector{Index{IndsT}} ) where {IndsT} @@ -280,16 +278,20 @@ function extractdiagonal( end end - projector = deepcopy(obj.projector) - for site in sites - if site' in keys(projector.data) - delete!(projector.data, site') - end + newD = Dict{Index,Int}() + # Duplicates of keys are discarded + for (k, v) in obj.projector.data + newk = ITensors.noprime(k) + newD[newk] = v end - return SubDomainMPS(MPS(tensors), projector) + return SubDomainMPS(MPS(tensors), Projector(newD)) end function extractdiagonal(obj::SubDomainMPS, tag::String)::SubDomainMPS targetsites = findallsiteinds_by_tag(unique(ITensors.noprime.(_allsites(obj))); tag=tag) return extractdiagonal(obj, targetsites) end + +function extractdiagonal(subdmps::SubDomainMPS, site::Index{IndsT}) where {IndsT} + return extractdiagonal(subdmps, [site]) +end diff --git a/src/util.jl b/src/util.jl index 7afdc28..73c7fcc 100644 --- a/src/util.jl +++ b/src/util.jl @@ -124,12 +124,14 @@ function findallsites_by_tag( _valid_tag(tag) || error("Invalid tag: $tag") # 1) Check if there is an Index with exactly `tag` - idx = findall(hastags(tag), sites) - if !isempty(idx) - if length(idx) > 1 - error("Found more than one site index with tag $(tag)!") + if tag != "" + idx = findall(hastags(tag), sites) + if !isempty(idx) + if length(idx) > 1 + error("Found more than one site index with tag $(tag)!") + end + return idx end - return idx end # 2) If not found, search for tag=1, tag=2, ... @@ -169,13 +171,15 @@ function findallsites_by_tag( sitesflatten = collect(Iterators.flatten(sites)) - idx_exact = findall(i -> hastags(i, tag) && hasplev(i, 0), sitesflatten) - if !isempty(idx_exact) - if length(idx_exact) > 1 - error("Found more than one site index with tag '$tag'!") + if tag != "" + idx_exact = findall(i -> hastags(i, tag) && hasplev(i, 0), sitesflatten) + if !isempty(idx_exact) + if length(idx_exact) > 1 + error("Found more than one site index with tag '$tag'!") + end + # Return a single position + return [sites_dict[sitesflatten[only(idx_exact)]]] end - # Return a single position - return [sites_dict[sitesflatten[only(idx_exact)]]] end result = NTuple{2,Int}[] @@ -201,6 +205,18 @@ function findallsiteinds_by_tag( return [sites[i][j] for (i, j) in positions] end +# FIXME: may be type unstable +# Gianluca: FIXED (?) +function _find_site_allplevs( + tensor::ITensor, site::Index{T}; maxplev=10 +)::Vector{Index{T}} where {T} + ITensors.plev(site) == 0 || error("Site index must be unprimed.") + return [ + ITensors.prime(site, plev) for + plev in 0:maxplev if ITensors.prime(site, plev) ∈ ITensors.inds(tensor) + ] +end + function makesitediagonal(M::AbstractMPS, tag::String)::MPS M_ = deepcopy(MPO(collect(M))) sites = siteinds(M_) @@ -225,3 +241,52 @@ function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number} end return ITensor(newdata, restinds..., site) end + +""" +Contract two adjacent tensors in MPO +""" +function combinesites(M::MPO, site1::Index, site2::Index) + p1 = findsite(M, site1) + p2 = findsite(M, site2) + p1 === nothing && error("Not found $site1") + p2 === nothing && error("Not found $site2") + abs(p1 - p2) == 1 || error( + "$site1 and $site2 are found at indices $p1 and $p2. They must be on two adjacent sites.", + ) + tensors = ITensors.data(M) + idx = min(p1, p2) + tensor = tensors[idx] * tensors[idx + 1] + deleteat!(tensors, idx:(idx + 1)) + insert!(tensors, idx, tensor) + return MPO(tensors) +end + +function combinesites( + sites::Vector{Vector{Index{IndsT}}}, + site1::AbstractVector{Index{IndsT}}, + site2::AbstractVector{Index{IndsT}}, +) where {IndsT} + length(site1) == length(site2) || error("Length mismatch") + for (s1, s2) in zip(site1, site2) + sites = combinesites(sites, s1, s2) + end + return sites +end + +function combinesites( + sites::Vector{Vector{Index{IndsT}}}, site1::Index, site2::Index +) where {IndsT} + sites = deepcopy(sites) + p1 = findfirst(x -> x[1] == site1, sites) + p2 = findfirst(x -> x[1] == site2, sites) + if p1 === nothing || p2 === nothing + error("Site not found") + end + if abs(p1 - p2) != 1 + error("Sites are not adjacent") + end + deleteat!(sites, min(p1, p2)) + deleteat!(sites, min(p1, p2)) + insert!(sites, min(p1, p2), [site1, site2]) + return sites +end diff --git a/test/_util.jl b/test/_util.jl index 1865d23..e4ba6b1 100644 --- a/test/_util.jl +++ b/test/_util.jl @@ -20,3 +20,7 @@ function _random_mpo( ) where {T} return _random_mpo(Random.default_rng(), sites; linkdims) end + +function _evaluate(Ψ::MPS, sites, index::Vector{Int}) + return only(reduce(*, Ψ[n] * onehot(sites[n] => index[n]) for n in 1:(length(Ψ)))) +end diff --git a/test/automul_tests.jl b/test/automul_tests.jl new file mode 100644 index 0000000..f002450 --- /dev/null +++ b/test/automul_tests.jl @@ -0,0 +1,112 @@ +using Test + +using Random +using ITensors +using ITensorMPS + +import PartitionedMPSs: + PartitionedMPSs, + PartitionedMPS, + SubDomainMPS, + makesitediagonal, + extractdiagonal, + project, + elemmul, + automul, + default_cutoff + +import FastMPOContractions as FMPOC + +@testset "automul.jl" begin + @testset "element-wise product" begin + Random.seed!(1234) + N = 5 + L = 10 # Bond dimension + d = 2 # Local dimension + + sites = [Index(d, "Qubit, n=$n") for n in 1:N] + sites_vec = [[x] for x in sites] + + Ψ = ITensorMPS.convert(MPS, _random_mpo(sites_vec; linkdims=L)) + dummy_subdmps = SubDomainMPS(Ψ) + + proj_lev_l = 2 # Max projected index left tensor + proj_lev_r = 3 # Max projected index right tensor + + proj_l = vec([ + Dict(zip(sites, combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_l)...) + ]) + + proj_r = vec([ + Dict(zip(sites, combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_r)...) + ]) + + partΨ_l = PartitionedMPS(project.(Ref(Ψ), proj_l)) + partΨ_r = PartitionedMPS(project.(Ref(Ψ), proj_r)) + + diag_dummy_l = makesitediagonal(dummy_subdmps, sites; baseplev=1) + diag_dummy_r = makesitediagonal(dummy_subdmps, sites; baseplev=0) + + elemmul_dummy = extractdiagonal( + PartitionedMPSs.contract(diag_dummy_l, diag_dummy_r; alg="zipup"), sites + ) + + element_prod = elemmul(partΨ_l, partΨ_r) + mps_element_prod = MPS(element_prod) + + @test mps_element_prod ≈ MPS(elemmul_dummy) + + test_points = [[rand(1:d) for __ in 1:N] for _ in 1:1000] + + isapprox( + [_evaluate(mps_element_prod, sites, p) for p in test_points], + [_evaluate(Ψ, sites, p)^2 for p in test_points]; + atol=sqrt(default_cutoff()), # default_cutoff() = 1e-25 is the contraction cutoff + ) + end + + @testset "matmul" begin + N = 10 + d = 2 + L = 5 + + sites_m = [Index(d, "Qubit, m=$m") for m in 1:N] + sites_n = [Index(d, "Qubit, n=$n") for n in 1:N] + sites_l = [Index(d, "Qubit, l=$l") for l in 1:N] + sites_mn = collect(Iterators.flatten(collect.(zip(sites_m, sites_n)))) + sites_nl = collect(Iterators.flatten(collect.(zip(sites_n, sites_l)))) + + Ψ_l = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_mn]; linkdims=L)) + Ψ_r = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_nl]; linkdims=L)) + + proj_lev_l = 4 + proj_lev_r = 6 + + proj_l = vec([ + Dict(zip(sites_mn, combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_l)...) + ]) + + proj_r = vec([ + Dict(zip(sites_nl, combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_r)...) + ]) + + partΨ_l = PartitionedMPS(project.(Ref(Ψ_l), proj_l)) + partΨ_r = PartitionedMPS(project.(Ref(Ψ_r), proj_r)) + + matmul = automul( + partΨ_l, partΨ_r; alg="zipup", tag_row="m", tag_shared="n", tag_col="l" + ) + mps_matmul = MPS(matmul) + + naive_matmul = FMPOC.contract_mpo_mpo( + MPO(collect(Ψ_l)), MPO(collect(Ψ_r)); alg="naive" + ) + mps_naive_matmul = ITensorMPS.convert(MPS, naive_matmul) + + @test mps_matmul ≈ mps_naive_matmul + end +end diff --git a/test/bak/conversion_tests.jl b/test/bak/conversion_tests.jl index ecd34d1..9f8ae02 100644 --- a/test/bak/conversion_tests.jl +++ b/test/bak/conversion_tests.jl @@ -5,7 +5,7 @@ using ITensors import TensorCrossInterpolation as TCI import TCIAlgorithms as TCIA using TCIITensorConversion -import PartitionedMPSs: SubDomainMPS, PartitionedMPS +import PartitionedMPSs: PartitionedMPSs, SubDomainMPS, PartitionedMPS #import FastMPOContractions as FMPOC #import Quantics: asMPO #using Quantics: Quantics diff --git a/test/contract_tests.jl b/test/contract_tests.jl index f2a4750..26720c9 100644 --- a/test/contract_tests.jl +++ b/test/contract_tests.jl @@ -1,5 +1,6 @@ using Test -import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, projcontract +import PartitionedMPSs: + PartitionedMPSs, PartitionedMPS, Projector, project, SubDomainMPS, projcontract import FastMPOContractions as FMPOC asMPO(M::AbstractMPS) = MPO(collect(M)) diff --git a/test/runtests.jl b/test/runtests.jl index fc186cf..97e7ed9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ include("subdomainmps_tests.jl") include("partitionedmps_tests.jl") include("contract_tests.jl") include("patching_tests.jl") -#include("util_tests.jl") +include("util_tests.jl") +include("automul_tests.jl") -#include("automul_tests.jl") +# include("bak/conversion_tests.jl") From c059811d0418f84d6de30834a723862aef6ddccf Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Mon, 3 Mar 2025 23:02:59 +0100 Subject: [PATCH 03/11] Implement parallel patch contraction --- Project.toml | 2 ++ src/PartitionedMPSs.jl | 2 ++ src/contract.jl | 68 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8d805c0..31156a4 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.5.4" [deps] Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037" CoverageTools = "c36e975a-824b-4404-a568-ef97ca766997" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" FastMPOContractions = "f6e391d2-8ffa-4d7a-98cd-7e70024481ca" ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2" @@ -17,6 +18,7 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" [compat] Coverage = "1" CoverageTools = "1" +Distributed = "1" EllipsisNotation = "1" FastMPOContractions = "0.2.5" ITensorMPS = "0.3.2" diff --git a/src/PartitionedMPSs.jl b/src/PartitionedMPSs.jl index 8b9ce86..38b004b 100644 --- a/src/PartitionedMPSs.jl +++ b/src/PartitionedMPSs.jl @@ -10,6 +10,8 @@ import ITensors.TagSets: hastag, hastags import FastMPOContractions as FMPOC +using Distributed + default_cutoff() = 1e-25 default_maxdim() = typemax(Int) diff --git a/src/contract.jl b/src/contract.jl index 35e5394..35641b3 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -143,10 +143,15 @@ function contract( cutoff=default_cutoff(), maxdim=default_maxdim(), patchorder=Index[], + parallel::Bool=false, kwargs..., )::Union{PartitionedMPS} M = PartitionedMPS() - return contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) + if parallel + return parallel_contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) + else + return contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) + end end """ @@ -186,3 +191,64 @@ function contract!( end return M end + +function parallel_contract!( + M::PartitionedMPS, + M1::PartitionedMPS, + M2::PartitionedMPS; + alg="zipup", + cutoff=default_cutoff(), + maxdim=default_maxdim(), + patchorder=Index[], + overwrite=true, + kwargs..., +)::Union{PartitionedMPS} + blocks_to_sets = OrderedDict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + + for b1 in values(M1) + for b2 in values(M2) + if hasoverlap(b1.projector, b2.projector) + r = _projector_after_contract(b1, b2)[1] + if haskey(blocks_to_sets, r) + set1, set2 = blocks_to_sets[r] + push!(set1, b1) + push!(set2, b2) + else + blocks_to_sets[r] = (Set([b1]), Set([b2])) + end + end + end + end + + for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) + if b1 != b2 && hasoverlap(b1, b2) + error("After contraction, projectors must not overlap.") + end + end + + # Builds tasks to parallelise + tasks = Vector{Tuple{Projector,Vector{SubDomainMPS},Vector{SubDomainMPS}}}() + for (proj, (set1, set2)) in blocks_to_sets + if haskey(M.data, proj) && !overwrite + continue + end + push!(tasks, (proj, collect(set1), collect(set2))) + end + + function process_task(task) + proj, M1_subs, M2_subs = task + return projcontract( + M1_subs, M2_subs, proj; alg, cutoff, maxdim, patchorder, kwargs... + ) + end + + results_parallel = pmap(process_task, tasks) + + for res in results_parallel + if res !== nothing + append!(M, res) + end + end + + return M +end From 7b91035d69e7fa4da95a7afcc3e4d6345ee604a7 Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Mon, 10 Mar 2025 01:37:42 +0100 Subject: [PATCH 04/11] Fix parallel contraction --- src/PartitionedMPSs.jl | 1 + src/contract.jl | 142 +++++++++++++++++++++++++++++++++-------- 2 files changed, 115 insertions(+), 28 deletions(-) diff --git a/src/PartitionedMPSs.jl b/src/PartitionedMPSs.jl index 38b004b..e68b67b 100644 --- a/src/PartitionedMPSs.jl +++ b/src/PartitionedMPSs.jl @@ -11,6 +11,7 @@ import ITensors.TagSets: hastag, hastags import FastMPOContractions as FMPOC using Distributed +using Base.Threads default_cutoff() = 1e-25 default_maxdim() = typemax(Int) diff --git a/src/contract.jl b/src/contract.jl index 35641b3..c8896e2 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -98,19 +98,47 @@ function projcontract( kwargs..., )::Union{Nothing,Vector{SubDomainMPS}} results = SubDomainMPS[] - #T1 = time_ns() - for M1_ in M1 - for M2_ in M2 - #t1 = time_ns() - r = projcontract(M1_, M2_, proj; alg, cutoff, maxdim, kwargs...) - #t2 = time_ns() + + # Precollect the pairs for threading + pairinfo = vec([(m1, m2, maxlinkdim(m1) * maxlinkdim(m2)) for m1 in M1, m2 in M2]) + # Heavy contraction first + sort!(pairinfo; by=x -> x[3], rev=true) + # Lock for threaded computation + local_lock = ReentrantLock() + + if Threads.nthreads() > 1 + nT = nthreads() + chunked_pairs = [Vector{Tuple{SubDomainMPS,SubDomainMPS}}() for _ in 1:nT] + # Equally divide expensive computations btw threads + for (i, (m1, m2, _)) in enumerate(pairinfo) + t = ((i - 1) % nT) + 1 + push!(chunked_pairs[t], (m1, m2)) + end + + @threads for t in 1:nT + local_buffer = SubDomainMPS[] + + for (m1, m2) in chunked_pairs[t] + r = projcontract(m1, m2, proj; alg, cutoff, maxdim, kwargs...) + + if r !== nothing + push!(local_buffer, r) # Thread-local accumulation + end + end + + # Lock is held only briefly to merge partial results + lock(local_lock) do + append!(results, local_buffer) + end + end + else + for (m1, m2, _) in pairinfo + r = projcontract(m1, m2, proj; alg, cutoff, maxdim, kwargs...) if r !== nothing push!(results, r) end end end - #T2 = time_ns() - #println("projcontract, all: $((T2 - T1)*1e-9) s") if isempty(results) return nothing @@ -170,25 +198,68 @@ function contract!( overwrite=true, kwargs..., )::Union{PartitionedMPS} - blocks = OrderedSet(( - _projector_after_contract(b1, b2)[1] for b1 in values(M1), b2 in values(M2) - )) + blocks = OrderedSet{Projector}() + blocks_to_sets = OrderedDict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + + function add_entry!(blocks::OrderedSet{T}, proj::T) where {T} + current_block = proj + for existing in copy(blocks) + if hasoverlap(existing, proj) + delete!(blocks, existing) + fused_proj = existing | proj + current_block = add_entry!(blocks, fused_proj) + end + end + push!(blocks, proj) + return current_block + end + + # Add a result's patch only if not overlapping with previously inserted ones + # Each block is obtained from contractable patches of the original factors + # Keep track of which SubDomainMPSs generate each resulting patch + for b1 in values(M1), b2 in values(M2) + if hasoverlap(b1.projector, b2.projector) + block_key = add_entry!(blocks, _projector_after_contract(b1, b2)[1]) + if haskey(blocks_to_sets, block_key) + set1, set2 = blocks_to_sets[block_key] + push!(set1, b1) + push!(set2, b2) + else + blocks_to_sets[block_key] = (Set([b1]), Set([b2])) + end + end + end + for b1 in blocks, b2 in blocks if b1 != b2 && hasoverlap(b1, b2) error("After contraction, projectors must not overlap.") end end - M1_::Vector{SubDomainMPS} = collect(values(M1)) - M2_::Vector{SubDomainMPS} = collect(values(M2)) - for b in blocks - if haskey(M.data, b) && !overwrite + + # Builds tasks to parallelise + tasks = Vector{Tuple{Projector,Vector{SubDomainMPS},Vector{SubDomainMPS}}}() + for (proj, (set1, set2)) in blocks_to_sets + if haskey(M.data, proj) && !overwrite continue end - res = projcontract(M1_, M2_, b; alg, cutoff, maxdim, patchorder, kwargs...) + push!(tasks, (proj, collect(set1), collect(set2))) + end + + function process_task(task) + proj, M1_subs, M2_subs = task + return projcontract( + M1_subs, M2_subs, proj; alg, cutoff, maxdim, patchorder, kwargs... + ) + end + + results = map(process_task, tasks) + + for res in results if res !== nothing append!(M, res) end end + return M end @@ -203,24 +274,39 @@ function parallel_contract!( overwrite=true, kwargs..., )::Union{PartitionedMPS} + blocks = OrderedSet{Projector}() blocks_to_sets = OrderedDict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() - for b1 in values(M1) - for b2 in values(M2) - if hasoverlap(b1.projector, b2.projector) - r = _projector_after_contract(b1, b2)[1] - if haskey(blocks_to_sets, r) - set1, set2 = blocks_to_sets[r] - push!(set1, b1) - push!(set2, b2) - else - blocks_to_sets[r] = (Set([b1]), Set([b2])) - end + function add_entry!(blocks::OrderedSet{T}, proj::T) where {T} + current_block = proj + for existing in copy(blocks) + if hasoverlap(existing, proj) + delete!(blocks, existing) + fused_proj = existing | proj + current_block = add_entry!(blocks, fused_proj) end end + push!(blocks, proj) + return current_block end - for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) + # Add a result's patch only if not overlapping with previously inserted ones + # Each block is obtained from contractable patches of the original factors + # Keep track of which SubDomainMPSs generate each resulting patch + for b1 in values(M1), b2 in values(M2) + if hasoverlap(b1.projector, b2.projector) + block_key = add_entry!(blocks, _projector_after_contract(b1, b2)[1]) + if haskey(blocks_to_sets, block_key) + set1, set2 = blocks_to_sets[block_key] + push!(set1, b1) + push!(set2, b2) + else + blocks_to_sets[block_key] = (Set([b1]), Set([b2])) + end + end + end + + for b1 in blocks, b2 in blocks if b1 != b2 && hasoverlap(b1, b2) error("After contraction, projectors must not overlap.") end From def7c64205e262d9f668b51ac31a8f99146dabc2 Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Mon, 10 Mar 2025 23:26:56 +0100 Subject: [PATCH 05/11] Add only distributed parallel contract --- src/contract.jl | 217 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 158 insertions(+), 59 deletions(-) diff --git a/src/contract.jl b/src/contract.jl index c8896e2..40daffa 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -171,14 +171,18 @@ function contract( cutoff=default_cutoff(), maxdim=default_maxdim(), patchorder=Index[], - parallel::Bool=false, + parallel::Symbol=:serial, kwargs..., )::Union{PartitionedMPS} M = PartitionedMPS() - if parallel + if parallel == :distributed_thread return parallel_contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) - else + elseif parallel == :distributed + return distribute_contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) + elseif parallel == :serial return contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) + else + error("Symbol $(parallel) not recongnized.") end end @@ -198,39 +202,22 @@ function contract!( overwrite=true, kwargs..., )::Union{PartitionedMPS} - blocks = OrderedSet{Projector}() - blocks_to_sets = OrderedDict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() - - function add_entry!(blocks::OrderedSet{T}, proj::T) where {T} - current_block = proj - for existing in copy(blocks) - if hasoverlap(existing, proj) - delete!(blocks, existing) - fused_proj = existing | proj - current_block = add_entry!(blocks, fused_proj) - end - end - push!(blocks, proj) - return current_block - end - - # Add a result's patch only if not overlapping with previously inserted ones - # Each block is obtained from contractable patches of the original factors - # Keep track of which SubDomainMPSs generate each resulting patch - for b1 in values(M1), b2 in values(M2) - if hasoverlap(b1.projector, b2.projector) - block_key = add_entry!(blocks, _projector_after_contract(b1, b2)[1]) - if haskey(blocks_to_sets, block_key) - set1, set2 = blocks_to_sets[block_key] - push!(set1, b1) - push!(set2, b2) + blocks_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + + for m1 in values(M1), m2 in values(M2) + if hasoverlap(m1.projector, m2.projector) + block = add_entry!(blocks_to_sets, _projector_after_contract(m1, m2)[1]) + if haskey(blocks_to_sets, block) + set1, set2 = blocks_to_sets[block] + push!(set1, m1) + push!(set2, m2) else - blocks_to_sets[block_key] = (Set([b1]), Set([b2])) + blocks_to_sets[block] = (Set([m1]), Set([m2])) end end end - for b1 in blocks, b2 in blocks + for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) if b1 != b2 && hasoverlap(b1, b2) error("After contraction, projectors must not overlap.") end @@ -274,39 +261,22 @@ function parallel_contract!( overwrite=true, kwargs..., )::Union{PartitionedMPS} - blocks = OrderedSet{Projector}() - blocks_to_sets = OrderedDict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() - - function add_entry!(blocks::OrderedSet{T}, proj::T) where {T} - current_block = proj - for existing in copy(blocks) - if hasoverlap(existing, proj) - delete!(blocks, existing) - fused_proj = existing | proj - current_block = add_entry!(blocks, fused_proj) - end - end - push!(blocks, proj) - return current_block - end - - # Add a result's patch only if not overlapping with previously inserted ones - # Each block is obtained from contractable patches of the original factors - # Keep track of which SubDomainMPSs generate each resulting patch - for b1 in values(M1), b2 in values(M2) - if hasoverlap(b1.projector, b2.projector) - block_key = add_entry!(blocks, _projector_after_contract(b1, b2)[1]) - if haskey(blocks_to_sets, block_key) - set1, set2 = blocks_to_sets[block_key] - push!(set1, b1) - push!(set2, b2) + blocks_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + + for m1 in values(M1), m2 in values(M2) + if hasoverlap(m1.projector, m2.projector) + block = add_entry!(blocks_to_sets, _projector_after_contract(m1, m2)[1]) + if haskey(blocks_to_sets, block) + set1, set2 = blocks_to_sets[block] + push!(set1, m1) + push!(set2, m2) else - blocks_to_sets[block_key] = (Set([b1]), Set([b2])) + blocks_to_sets[block] = (Set([m1]), Set([m2])) end end end - for b1 in blocks, b2 in blocks + for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) if b1 != b2 && hasoverlap(b1, b2) error("After contraction, projectors must not overlap.") end @@ -338,3 +308,132 @@ function parallel_contract!( return M end + +function add_entry!( + dict::Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}, proj::Projector +) + # Iterate over a copy of keys to avoid modifying the dict while looping. + for existing in collect(keys(dict)) + if hasoverlap(existing, proj) + fused_proj = existing | proj + # Save the current value for the overlapping key. + val = dict[existing] + # Remove the old key (this deletes its associated value). + delete!(dict, existing) + # Recursively update with the fused projector. + new_key = add_entry!(dict, fused_proj) + # If new_key is already present, merge the values; otherwise, insert the saved value. + if haskey(dict, new_key) + old_val = dict[new_key] + dict[new_key] = (union(old_val[1], val[1]), union(old_val[2], val[2])) + else + dict[new_key] = val + end + return new_key + end + end + # If no overlapping key is found, then ensure proj is in the dictionary. + if !haskey(dict, proj) + dict[proj] = (Set{SubDomainMPS}(), Set{SubDomainMPS}()) + end + return proj +end + +function distribute_contract!( + M::PartitionedMPS, + M1::PartitionedMPS, + M2::PartitionedMPS; + alg="zipup", + alg_sum="fit", + cutoff=default_cutoff(), + maxdim=default_maxdim(), + patchorder=Index[], + overwrite=true, + kwargs..., +)::Union{PartitionedMPS} + blocks_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + + for m1 in values(M1), m2 in values(M2) + if hasoverlap(m1.projector, m2.projector) + block = add_entry!(blocks_to_sets, _projector_after_contract(m1, m2)[1]) + if haskey(blocks_to_sets, block) + set1, set2 = blocks_to_sets[block] + push!(set1, m1) + push!(set2, m2) + else + blocks_to_sets[block] = (Set([m1]), Set([m2])) + end + end + end + + for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) + if b1 != b2 && hasoverlap(b1, b2) + error("After contraction, projectors must not overlap.") + end + end + + tasks = Vector{Tuple{Projector,SubDomainMPS,SubDomainMPS}}() + for (proj, (set1, set2)) in blocks_to_sets + for subdmps1 in set1, subdmps2 in set2 + if haskey(M.data, proj) && !overwrite + continue + end + push!(tasks, (proj, subdmps1, subdmps2)) + end + end + + function process_task(task_tuple; alg, cutoff, maxdim, kwargs...) + # Unpack the tuple + proj, subdmps1, subdmps2 = task_tuple + res = projcontract(subdmps1, subdmps2, proj; alg, cutoff, maxdim, kwargs...) + return (proj, res) + end + + results = pmap(task -> process_task(task; alg, cutoff, maxdim, kwargs...), tasks) + valid_results = filter(x -> x[2] !== nothing, results) + + block_group = Dict{Projector,Vector{SubDomainMPS}}() + for (b, subdmps) in valid_results + if haskey(block_group, b) + push!(block_group[b], subdmps) + else + block_group[b] = [subdmps] + end + end + + block_group_array = collect(block_group) + + function sum_blocks(group; patchorder, alg_sum, cutoff, maxdim, kwargs...) + b, subdmps_list = group + if length(subdmps_list) == 1 + return [subdmps_list[1]] + else + res = if length(patchorder) > 0 + _add_patching(subdmps_list; cutoff, maxdim, patchorder, kwargs...) + else + [_add(subdmps_list...; alg=alg_sum, cutoff, maxdim, kwargs...)] + end + return res + end + end + + summed_patches = pmap( + group -> sum_blocks( + group; + patchorder=patchorder, + alg_sum=alg_sum, + cutoff=cutoff, + maxdim=maxdim, + kwargs..., + ), + block_group_array, + ) + + for res in summed_patches + if res !== nothing + append!(M, vcat(res)) + end + end + + return M +end From 398680429c7521efc29e5a9def1960b6edf16653 Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Mon, 24 Mar 2025 17:15:16 +0100 Subject: [PATCH 06/11] Add prime and noprime feature --- src/partitionedmps.jl | 6 ++++++ src/projector.jl | 41 +++++++++++++++++++++++++++++++++++++++++ src/subdomainmps.jl | 33 +++++++++++++++++++++++++-------- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/src/partitionedmps.jl b/src/partitionedmps.jl index da653bb..d7c67e4 100644 --- a/src/partitionedmps.jl +++ b/src/partitionedmps.jl @@ -101,6 +101,12 @@ function prime(Ψ::PartitionedMPS, args...; kwargs...) ]) end +function noprime(Ψ::PartitionedMPS, args...; kwargs...) + return PartitionedMPS([ + noprime(subdmps, args...; kwargs...) for subdmps in values(Ψ.data) + ]) +end + """ Return the norm of the PartitionedMPS. """ diff --git a/src/projector.jl b/src/projector.jl index 6879241..397ca0e 100644 --- a/src/projector.jl +++ b/src/projector.jl @@ -122,3 +122,44 @@ function Base.isdisjoint(projectors::AbstractVector{Projector})::Bool end return true end + +""" +Remove prime level in projected indices. +""" +function noprime( + p::Projector; targetsites::Union{Nothing,AbstractVector{Index{T}}}=nothing +)::Projector where {T} + if isnothing(targetsites) + targetsites = keys(p) + elseif targetsites ⊈ keys(p) + error("Target sites are not projected indices.") + end + + if isempty(targetsites) + return p + end + + new_dict = Dict(k ∈ targetsites ? ITensors.noprime(k) => v : k => v for (k, v) in p) + return Projector(new_dict) +end + +""" +Prime projected indices. +""" +function prime( + p::Projector; targetsites::Union{Nothing,AbstractVector{Index{T}}}=nothing +)::Projector where {T} + if isnothing(targetsites) + targetsites = keys(p) + elseif targetsites ⊈ keys(p) + error("Target sites are not projected indices.") + end + + if isempty(targetsites) + return p + end + + new_dict = Dict(k ∈ targetsites ? ITensors.prime(k) => v : k => v for (k, v) in p) + + return Projector(new_dict) +end diff --git a/src/subdomainmps.jl b/src/subdomainmps.jl index 82c72e8..ff9fe6c 100644 --- a/src/subdomainmps.jl +++ b/src/subdomainmps.jl @@ -114,10 +114,28 @@ function Base.show(io::IO, obj::SubDomainMPS) end function prime(Ψ::SubDomainMPS, args...; kwargs...) + if :inds ∈ keys(kwargs) + targetsites = kwargs[:inds] + else + targetsites = nothing + end + return SubDomainMPS( ITensors.prime(MPS(Ψ), args...; kwargs...), - ITensors.prime.(siteinds(Ψ), args...; kwargs...), - Ψ.projector, + PartitionedMPSs.prime(Ψ.projector; targetsites=targetsites), + ) +end + +function noprime(Ψ::SubDomainMPS, args...; kwargs...) + if :inds ∈ keys(kwargs) + targetsites = kwargs[:inds] + else + targetsites = nothing + end + + return SubDomainMPS( + ITensors.noprime(MPS(Ψ), args...; kwargs...), + PartitionedMPSs.noprime(Ψ.projector; targetsites=targetsites), ) end @@ -136,8 +154,7 @@ function _fitsum( kwargs..., ) where {T} if !(:nsweeps ∈ keys(kwargs)) - kwargs = Dict{Symbol,Any}(kwargs) - kwargs[:nsweeps] = 1 + kwargs = merge(Dict(kwargs), Dict(:nsweeps => 1)) end Ψs = [MPS(collect(x)) for x in input_states] init_Ψ = MPS(collect(init)) @@ -158,7 +175,7 @@ function _add(ψ::AbstractMPS...; alg="fit", cutoff=1e-15, maxdim=typemax(Int), elseif alg == "fit" function f(x, y) return ITensors.truncate( - +(ITensors.Algorithm("directsum"), x, y); cutoff, maxdim + +(ITensors.Algorithm("directsum"), x, y); cutoff, maxdim, kwargs... ) end res_dm = reduce(f, ψ) @@ -170,16 +187,16 @@ function _add(ψ::AbstractMPS...; alg="fit", cutoff=1e-15, maxdim=typemax(Int), end function Base.:+( - Ψ::SubDomainMPS...; alg="directsum", cutoff=0.0, maxdim=typemax(Int), kwargs... + Ψ::SubDomainMPS...; alg="fit", cutoff=0.0, maxdim=typemax(Int), kwargs... )::SubDomainMPS return _add(Ψ...; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs...) end function _add( - Ψ::SubDomainMPS...; alg="directsum", cutoff=0.0, maxdim=typemax(Int), kwargs... + Ψ::SubDomainMPS...; alg="fit", cutoff=0.0, maxdim=typemax(Int), kwargs... )::SubDomainMPS return project( - _add([x.data for x in Ψ]...; alg=alg, cutoff=cutoff, maxdim=maxdim), + _add([x.data for x in Ψ]...; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs...), reduce(|, [x.projector for x in Ψ]), ) end From 394101dad6992ed6831ebb81a1cacf36f4f416fb Mon Sep 17 00:00:00 2001 From: Samuel Badr Date: Wed, 2 Apr 2025 15:52:15 +0200 Subject: [PATCH 07/11] Improve prime(::Projector) and add tests --- src/projector.jl | 29 +++++++++++++++++++---------- src/subdomainmps.jl | 14 ++++---------- test/partitionedmps_tests.jl | 36 +++++++++++++++++++++++++++++++++++- test/subdomainmps_tests.jl | 4 ++-- 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/projector.jl b/src/projector.jl index 397ca0e..14a1309 100644 --- a/src/projector.jl +++ b/src/projector.jl @@ -146,20 +146,29 @@ end """ Prime projected indices. """ -function prime( - p::Projector; targetsites::Union{Nothing,AbstractVector{Index{T}}}=nothing -)::Projector where {T} - if isnothing(targetsites) - targetsites = keys(p) - elseif targetsites ⊈ keys(p) - error("Target sites are not projected indices.") +function prime(p::Projector, plinc=1; kwargs...)::Projector + targetsites = if :inds ∈ keys(kwargs) + kwargs[:inds] + else + keys(p) end + plev = if :plev ∈ keys(kwargs) + kwargs[:plev] + end + isempty(targetsites) && return p - if isempty(targetsites) - return p + function new_ind(k) + if k ∉ targetsites + return k + end + if isnothing(plev) || ITensors.hasplev(k, plev) + return ITensors.prime(k, plinc) + else + return k + end end - new_dict = Dict(k ∈ targetsites ? ITensors.prime(k) => v : k => v for (k, v) in p) + new_dict = Dict(new_ind(k) => v for (k, v) in p) return Projector(new_dict) end diff --git a/src/subdomainmps.jl b/src/subdomainmps.jl index ff9fe6c..7162545 100644 --- a/src/subdomainmps.jl +++ b/src/subdomainmps.jl @@ -113,16 +113,10 @@ function Base.show(io::IO, obj::SubDomainMPS) return print(io, "SubDomainMPS projected on $(obj.projector.data)") end -function prime(Ψ::SubDomainMPS, args...; kwargs...) - if :inds ∈ keys(kwargs) - targetsites = kwargs[:inds] - else - targetsites = nothing - end - +function prime(Ψ::SubDomainMPS, plinc=1; kwargs...) return SubDomainMPS( - ITensors.prime(MPS(Ψ), args...; kwargs...), - PartitionedMPSs.prime(Ψ.projector; targetsites=targetsites), + ITensors.prime(MPS(Ψ), plinc; kwargs...), + PartitionedMPSs.prime(Ψ.projector, plinc; kwargs...), ) end @@ -135,7 +129,7 @@ function noprime(Ψ::SubDomainMPS, args...; kwargs...) return SubDomainMPS( ITensors.noprime(MPS(Ψ), args...; kwargs...), - PartitionedMPSs.noprime(Ψ.projector; targetsites=targetsites), + PartitionedMPSs.noprime(Ψ.projector; targetsites), ) end diff --git a/test/partitionedmps_tests.jl b/test/partitionedmps_tests.jl index d60bb3e..da21d1f 100644 --- a/test/partitionedmps_tests.jl +++ b/test/partitionedmps_tests.jl @@ -4,7 +4,8 @@ using ITensors using ITensorMPS using Random -import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, PartitionedMPS +import PartitionedMPSs: + PartitionedMPSs, Projector, project, SubDomainMPS, PartitionedMPS, prime, noprime @testset "partitionedmps.jl" begin @testset "two blocks" begin @@ -100,4 +101,37 @@ import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, Parti @test diff < cutoff_global end end + + @testset "prime" begin + Random.seed!(1234) + N = 3 + sitesx = [Index(2, "x=$n") for n in 1:N] + sitesy = [Index(2, "y=$n") for n in 1:N] + + sites = collect(collect.(zip(sitesx, sitesy))) + + Ψ = MPS(collect(_random_mpo(sites))) + + prjΨ = SubDomainMPS(Ψ) + + prjΨ1 = project(prjΨ, Dict(sitesx[1] => 1)) + prjΨ2 = project(prjΨ, Dict(sitesx[1] => 2)) + + Ψreconst = PartitionedMPS(prjΨ1) + PartitionedMPS(prjΨ2) + + Ψreconst_x3prime = prime(Ψreconst, 3; inds=sitesx) + + @test Set(Iterators.flatten(siteinds(Ψreconst_x3prime))) == + Set(union(ITensors.prime.(sitesx, 3), sitesy)) + + @test Set(keys(Ψreconst_x3prime)) == Set([ + Projector(ITensors.prime(sitesx[1], 3) => 1), + Projector(ITensors.prime(sitesx[1], 3) => 2), + ]) + + Ψreconst_x3prime_yprime = prime(Ψreconst_x3prime; plev=0) + + @test Set(Iterators.flatten(siteinds(Ψreconst_x3prime_yprime))) == + Set(union(ITensors.prime.(sitesx, 3), ITensors.prime.(sitesy, 1))) + end end diff --git a/test/subdomainmps_tests.jl b/test/subdomainmps_tests.jl index 6e327a9..289d859 100644 --- a/test/subdomainmps_tests.jl +++ b/test/subdomainmps_tests.jl @@ -89,7 +89,7 @@ import PartitionedMPSs: index_dict = Dict{Index{Int},Vector{Int}}() for (i, el) in enumerate(ind) - baseind = noprime(el) + baseind = ITensors.noprime(el) if haskey(index_dict, baseind) push!(index_dict[baseind], i) else @@ -101,7 +101,7 @@ import PartitionedMPSs: isdiagonalelement = all(allequal(val[i] for i in is) for is in repeated_indices) if isdiagonalelement - nondiaginds = unique(noprime(i) => v for (i, v) in indval) + nondiaginds = unique(ITensors.noprime(i) => v for (i, v) in indval) diag_ok = diag_ok && (psi_diag[indval...] == psi[nondiaginds...]) else offdiag_ok = offdiag_ok && iszero(psi_diag[indval...]) From 543662a6bfd0597839ac02702fbd8772fb84fcd7 Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Wed, 16 Apr 2025 17:14:10 +0200 Subject: [PATCH 08/11] Update coverage and Project.toml --- Project.toml | 14 +- coverage/lcov.info | 1841 +++++++++++++++------------------- src/contract.jl | 236 ++--- test/bak/conversion_tests.jl | 8 +- test/runtests.jl | 5 +- 5 files changed, 878 insertions(+), 1226 deletions(-) diff --git a/Project.toml b/Project.toml index 31156a4..59dea03 100644 --- a/Project.toml +++ b/Project.toml @@ -4,32 +4,32 @@ authors = ["Hiroshi Shinaoka and contributors"] version = "0.5.4" [deps] -Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037" -CoverageTools = "c36e975a-824b-4404-a568-ef97ca766997" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" FastMPOContractions = "f6e391d2-8ffa-4d7a-98cd-7e70024481ca" ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LocalCoverage = "5f6e1e16-694c-5876-87ef-16b5274f298e" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +[sources] +TCIAlgorithms = {url = "https://github.com/tensor4all/TCIAlgorithms.jl.git"} + [compat] -Coverage = "1" -CoverageTools = "1" Distributed = "1" EllipsisNotation = "1" FastMPOContractions = "0.2.5" ITensorMPS = "0.3.2" ITensors = "0.7" -LocalCoverage = "0.8" OrderedCollections = "1.6.3" julia = "1.6" [extras] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +TCIAlgorithms = "baf62351-2e82-41dd-9129-4f5768a618e1" +TCIITensorConversion = "9f0aa9f4-9415-4e6a-8795-331ebf40aa04" +TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random"] +test = ["Test", "Random", "TCIAlgorithms", "TensorCrossInterpolation", "TCIITensorConversion"] diff --git a/coverage/lcov.info b/coverage/lcov.info index 52528b3..62bf682 100644 --- a/coverage/lcov.info +++ b/coverage/lcov.info @@ -1,1079 +1,824 @@ -SF:src/TCIAlgorithms.jl -DA:1,13 +SF:src/PartitionedMPSs.jl +DA:16,0 +DA:17,47 LH:1 -LF:1 -end_of_record -SF:src/adaptivematmul.jl -DA:1,12 -DA:8,6 -DA:10,6 -DA:11,18 -DA:12,214 -DA:13,114 -DA:14,100 -DA:16,120 -DA:19,6 -DA:21,6 -DA:23,6 -DA:24,6 -DA:25,220 -DA:26,110 -DA:28,6 -DA:35,12 -DA:42,6 -DA:43,20 -DA:44,20 -DA:45,40 -DA:46,70 -DA:47,740 -DA:48,628 -DA:50,112 -DA:51,112 -DA:52,226 -DA:55,110 -DA:56,110 -DA:58,2 -DA:59,2 -DA:60,12 -DA:61,12 -DA:62,12 -DA:64,740 -DA:65,120 -DA:66,20 -DA:67,6 -DA:69,14 -DA:70,12 -DA:71,28 -DA:72,110 -DA:73,110 -DA:74,50 -DA:75,6 -DA:78,98 -DA:83,49 -DA:89,49 -DA:90,49 -DA:91,364 -DA:92,0 -DA:94,182 -DA:96,49 -DA:97,33 -DA:100,149 -DA:102,32 -DA:104,11 -DA:107,5 -DA:114,2 -DA:121,1 -DA:122,5 -DA:123,24 -DA:124,16 -DA:125,8 -DA:127,8 -DA:128,32 -DA:129,8 -DA:130,8 -DA:132,32 -DA:133,19 -DA:135,1 -LH:69 -LF:70 +LF:2 end_of_record -SF:src/blockstructure.jl -DA:8,4 -DA:9,4 -DA:10,26 -DA:11,17 -DA:12,1 -DA:15,30 -DA:16,3 -DA:20,5 -DA:22,22 -DA:23,26 -DA:24,4 -DA:26,14 +SF:src/adaptivemul.jl +DA:8,0 +DA:9,0 +DA:10,0 +DA:11,0 +DA:12,0 +DA:14,0 +DA:15,0 +DA:17,0 +DA:21,0 +DA:23,0 +DA:24,0 +DA:26,0 DA:29,0 -LH:12 -LF:13 +DA:36,0 +DA:37,0 +DA:38,0 +DA:39,0 +DA:48,0 +DA:57,0 +DA:59,0 +DA:60,0 +DA:61,0 +DA:62,0 +DA:64,0 +DA:65,0 +DA:67,0 +DA:69,0 +DA:73,0 +DA:75,0 +DA:76,0 +DA:77,0 +DA:79,0 +DA:80,0 +DA:82,0 +LH:0 +LF:34 end_of_record -SF:src/container.jl -DA:6,47 -DA:7,47 -DA:8,47 -DA:9,47 -DA:10,608 -DA:11,608 -DA:12,94 -DA:13,47 -DA:17,240 -DA:18,268 -DA:19,29 -DA:21,182 -DA:24,0 -DA:25,4 -DA:29,15 -DA:30,15 -DA:33,10 -DA:34,10 -DA:41,15 -DA:44,15 -DA:47,8 -DA:50,4 -DA:51,12 -DA:59,1 -DA:60,1 -DA:63,2454 -DA:64,2454 -DA:67,1 +SF:src/automul.jl +DA:16,2 +DA:24,1 +DA:25,1 +DA:27,1 +DA:29,1 +DA:31,1 +DA:32,1 +DA:34,1 +DA:36,1 +DA:38,1 +DA:61,2 +DA:72,1 DA:73,1 -DA:74,1 -DA:75,0 +DA:75,1 +DA:76,1 +DA:77,1 DA:78,1 -DA:79,1 -DA:80,6 -DA:81,3 +DA:80,1 +DA:81,1 +DA:82,1 DA:83,1 -DA:84,1 DA:85,1 -DA:87,1 -DA:94,1 -DA:95,4 +DA:86,1 +DA:88,1 +DA:89,1 +DA:91,1 +DA:93,1 +DA:95,1 +DA:97,1 DA:99,1 -DA:104,1 -DA:105,1 -DA:113,0 -DA:114,0 -DA:117,0 -DA:118,0 -DA:121,8 -DA:122,8 -LH:44 -LF:50 -end_of_record -SF:src/crossinterpolate.jl -DA:6,231 -DA:11,231 -DA:12,231 -DA:13,231 -DA:14,231 -DA:17,1 -DA:19,125154 -DA:20,125154 -DA:22,1 -DA:23,1 -DA:26,327149 -DA:31,327149 -DA:32,0 -DA:35,327149 -DA:36,327149 -DA:39,327149 -DA:43,327149 -DA:44,327149 -DA:45,327149 -DA:46,355925 -DA:49,327149 -DA:52,327149 -DA:59,654304 -DA:62,654304 -DA:63,49476 -DA:66,604828 -DA:67,604828 -DA:68,604828 -DA:69,604828 -DA:70,7422974 -DA:72,6658365 -DA:74,6658365 -DA:76,7422974 -DA:77,604828 -DA:78,604828 -DA:80,6818146 -DA:82,7422974 -DA:83,604828 -DA:85,604828 -DA:86,2537886 -DA:87,1933058 -DA:89,604828 -DA:92,49476 -DA:95,49476 -DA:96,118848 -DA:97,118848 -DA:99,98952 -DA:100,49476 -DA:103,327152 -DA:106,654304 -DA:107,327152 -DA:114,230 -DA:127,0 -DA:128,0 -DA:131,220 -DA:132,220 -DA:147,29 -DA:163,10 -DA:164,10 -DA:169,10 -DA:170,0 -DA:174,10 -DA:189,3 -DA:196,1 -DA:201,458 -DA:210,229 -DA:211,229 -DA:225,458 -DA:226,0 -DA:231,229 -DA:232,458 -DA:234,229 -DA:239,230 -DA:240,230 -DA:241,230 -DA:243,230 -DA:245,85 -DA:250,85 -DA:251,85 -DA:252,0 -DA:254,85 -DA:257,145 -DA:258,0 -DA:259,145 -DA:260,145 -DA:261,2 -DA:262,2 -DA:263,2 -DA:265,2 -DA:267,145 -DA:271,289 -DA:272,1 -DA:274,373 -DA:277,229 -DA:288,20 -DA:291,10 -DA:292,10 -DA:293,230 -DA:295,10 -DA:298,17 -DA:306,8 -DA:307,8 -DA:317,8 -DA:318,8 -LH:97 -LF:104 -end_of_record -SF:src/distribute.jl -DA:4,11 -DA:8,11 -DA:9,11 -DA:12,22 -DA:13,11 -DA:14,22 -DA:15,7 -DA:20,22 -DA:22,22 -DA:23,24 -DA:24,250 -DA:25,195 -DA:27,250 -DA:28,65 -DA:30,270 -DA:31,22 -DA:33,11 -LH:17 -LF:17 +DA:100,1 +DA:101,10 +DA:102,10 +DA:103,0 +DA:105,10 +DA:106,5 +DA:107,5 +DA:109,5 +DA:110,5 +DA:113,10 +DA:114,1 +LH:40 +LF:41 end_of_record -SF:src/itensor.jl -DA:11,6 -DA:14,6 -DA:17,6 -DA:21,2 -DA:22,2 -DA:23,2 -DA:24,2 -DA:27,0 -DA:28,0 -DA:34,6 -DA:35,6 -DA:36,6 -DA:37,6 -DA:38,6 -DA:39,3 -DA:40,3 -DA:41,6 -DA:42,6 -DA:46,2 -DA:49,1 -DA:50,1 +SF:src/bak/conversion.jl +DA:4,3 +DA:5,3 +DA:7,3 +DA:8,3 +DA:9,3 +DA:11,3 +DA:20,3 +DA:21,6 +DA:30,9 +DA:32,3 +DA:41,3 +DA:42,3 +DA:43,12 +DA:44,24 +DA:45,6 +DA:47,36 +DA:48,21 +DA:50,3 DA:56,1 -DA:58,1 -DA:60,1 -DA:61,1 -DA:62,1 -DA:64,2 -DA:73,1 -DA:74,0 +DA:57,1 +LH:20 +LF:20 +end_of_record +SF:src/contract.jl +DA:9,2198 +DA:13,1099 +DA:14,800 +DA:16,598 +DA:18,299 +DA:19,299 +DA:22,299 +DA:26,1676 +DA:27,1676 +DA:28,1676 +DA:30,1676 +DA:32,1676 +DA:33,800 +DA:36,876 +DA:37,14001 +DA:39,876 +DA:40,17237 +DA:41,1669 +DA:43,17237 +DA:44,2459 +DA:46,17237 +DA:48,876 +DA:52,1097 +DA:53,2194 +DA:54,40906 +DA:55,0 +DA:57,10582 +DA:58,1097 +DA:64,2386 +DA:74,1193 +DA:75,1193 +DA:76,2322 +DA:77,96 +DA:80,1394 +DA:82,1097 DA:83,0 -DA:85,2 -DA:94,1 -DA:97,2 -DA:99,8 -DA:100,34 -DA:103,8 -DA:104,8 -DA:105,8 -DA:106,8 -DA:109,3 -DA:110,3 -DA:111,9 -DA:112,3 -DA:113,3 +DA:86,1097 +DA:87,1097 +DA:94,112 +DA:105,56 +DA:107,56 +DA:108,1192 +DA:109,1192 +DA:110,296 +DA:112,1192 +DA:114,56 DA:115,0 -DA:116,0 -DA:119,3 -DA:120,3 -DA:121,6 -DA:122,3 -DA:123,3 -DA:124,0 -DA:126,3 -DA:127,3 -DA:128,3 -DA:131,3 -DA:132,3 -DA:139,2 -DA:140,1 -DA:141,1 -DA:142,1 -DA:143,1 -DA:144,1 -DA:145,0 -DA:154,0 -DA:155,1 -DA:158,1 -DA:161,6 -DA:164,6 -DA:165,6 -DA:167,6 -DA:168,6 -DA:174,6 -DA:175,4 -DA:184,3 -DA:185,9 -DA:195,6 +DA:118,56 +DA:119,8 +DA:122,48 +DA:123,4 +DA:125,92 +DA:128,48 +DA:136,8 +DA:146,4 +DA:147,4 +DA:148,0 +DA:149,4 +DA:150,4 +DA:152,0 +DA:156,512 +DA:160,512 +DA:161,9768 +DA:162,232 +DA:164,232 +DA:166,232 +DA:168,232 +DA:170,232 +DA:171,232 +DA:172,464 +DA:174,0 +DA:176,232 +DA:178,9536 +DA:180,280 +DA:181,280 +DA:183,280 +DA:191,8 +DA:202,4 +DA:204,32 +DA:205,1088 +DA:206,280 +DA:207,280 +DA:208,280 +DA:209,280 +DA:210,280 DA:212,0 -DA:213,0 -DA:214,0 -DA:216,0 -DA:217,0 -DA:218,0 -LH:68 -LF:83 +DA:215,1112 +DA:217,52 +DA:218,1120 +DA:219,0 +DA:221,1168 +DA:224,4 +DA:225,8 +DA:226,48 +DA:227,0 +DA:229,48 +DA:230,92 +DA:232,52 +DA:233,48 +DA:234,48 +DA:239,4 +DA:241,4 +DA:242,48 +DA:243,48 +DA:245,48 +DA:247,4 +DA:250,0 +DA:262,0 +DA:264,0 +DA:265,0 +DA:266,0 +DA:267,0 +DA:268,0 +DA:269,0 +DA:270,0 +DA:272,0 +DA:275,0 +DA:277,0 +DA:278,0 +DA:279,0 +DA:281,0 +DA:283,0 +DA:284,0 +DA:285,0 +DA:286,0 +DA:287,0 +DA:289,0 +DA:290,0 +DA:291,0 +DA:293,0 +DA:300,0 +DA:301,0 +DA:303,0 +DA:304,0 +DA:305,0 +DA:306,0 +DA:308,0 +DA:310,0 +DA:312,0 +DA:314,0 +DA:328,0 +DA:340,0 +DA:341,0 +DA:342,0 +DA:344,0 +DA:346,0 +LH:97 +LF:146 end_of_record -SF:src/mul.jl -DA:10,142 -DA:18,284 -DA:22,142 -DA:23,142 -DA:24,142 -DA:25,142 -DA:26,142 -DA:27,142 -DA:28,142 -DA:31,142 -DA:34,240 -DA:35,120 -DA:38,1 -DA:40,44 -DA:43,22 -DA:46,22 -DA:49,22 -DA:50,22 -DA:51,22 -DA:54,416 -DA:57,150 -DA:58,30 -DA:60,120 -DA:70,16854 -DA:71,16854 -DA:75,1 -DA:84,1 -DA:85,1 -DA:86,0 -DA:88,1 -DA:89,1 -DA:90,1 -DA:91,1 -DA:92,1 -DA:95,2 -DA:99,1 -DA:103,1 -DA:106,1 -DA:109,1 -DA:112,1 -DA:113,5 -DA:114,24 -DA:115,16 -DA:116,8 -DA:118,19 -DA:119,1 -DA:122,242 -DA:125,121 -DA:128,121 -DA:131,121 -DA:134,121 -DA:135,121 -DA:136,121 -DA:138,121 -DA:139,121 -DA:141,121 -DA:142,121 -DA:144,121 -DA:146,121 -DA:148,121 -DA:151,0 -DA:152,0 -LH:59 -LF:62 +SF:src/partitionedmps.jl +DA:8,172 +DA:9,172 +DA:10,172 +DA:11,1713 +DA:12,3355 +DA:13,174 +DA:15,1881 +DA:18,170 +DA:22,20 +DA:24,24 +DA:26,134 +DA:28,67 +DA:29,134 +DA:30,0 +DA:32,134 +DA:33,1141 +DA:34,2215 +DA:35,67 +DA:38,67 +DA:39,67 +DA:46,19 +DA:47,19 +DA:50,19 +DA:56,2 +DA:61,0 +DA:62,0 +DA:63,0 +DA:66,643 +DA:68,1 +DA:69,1 +DA:72,1 +DA:73,2 +DA:79,494 +DA:80,494 +DA:86,69 +DA:87,69 +DA:94,3 +DA:95,3 +DA:98,4 +DA:99,2 +DA:104,0 +DA:105,0 +DA:113,1 +DA:114,1 +DA:123,2 +DA:131,1 +DA:132,2 +DA:133,10 +DA:134,19 +DA:135,1 +DA:138,38 +DA:147,19 +DA:148,19 +DA:151,38 +DA:162,19 +DA:163,19 +DA:164,19 +DA:165,73 +DA:166,0 +DA:168,73 +DA:169,4 +DA:170,4 +DA:173,69 +DA:174,52 +DA:175,17 +DA:176,17 +DA:178,0 +DA:180,73 +DA:181,19 +DA:184,14 +DA:185,14 +DA:188,13 +DA:189,13 +DA:192,0 +DA:193,0 +DA:203,14 +DA:211,7 +DA:212,7 +DA:213,7 +DA:215,7 +DA:217,7 +DA:218,19 +DA:222,19 +DA:223,33 +DA:224,5 +DA:227,14 +DA:228,14 +DA:230,7 +DA:234,62 +DA:237,31 +DA:238,1215 +DA:242,0 +DA:245,0 +DA:251,8 +DA:254,4 +DA:259,0 +DA:260,0 +DA:269,2 +DA:270,2 +DA:273,19 +DA:274,19 +DA:277,6 +DA:278,3 +LH:89 +LF:103 end_of_record SF:src/patching.jl -DA:6,15 -DA:7,15 -DA:8,15 -DA:12,6 -DA:14,0 +DA:7,4088 +DA:14,2044 DA:15,0 -DA:16,0 -DA:18,0 -DA:21,7960 -DA:26,1 -DA:27,4 -DA:28,1 -DA:29,1 -DA:39,230 -DA:43,230 -DA:44,460 -DA:45,230 -DA:48,460 -DA:49,230 -DA:50,230 -DA:52,230 -DA:53,175 -DA:54,175 -DA:55,1 -DA:57,349 -DA:59,175 -DA:60,175 -DA:62,55 -DA:63,55 -DA:64,220 -DA:65,220 -DA:69,220 -DA:70,385 -DA:71,55 -DA:75,1 -DA:76,1 -DA:77,1 -DA:80,220 -DA:81,220 -DA:85,220 -DA:86,220 -DA:87,220 -DA:88,220 -DA:91,395 -DA:92,395 -DA:93,395 -DA:94,1119 -DA:95,1849 -DA:96,395 -DA:99,1 -DA:102,5 -DA:103,1 -DA:104,2 +DA:20,2044 +DA:23,2044 +DA:27,2040 +DA:29,1020 +DA:31,1020 +DA:32,1020 +DA:33,2040 +DA:34,2040 +DA:42,3060 +DA:44,1020 +DA:50,1021 +DA:51,19447 +DA:52,1021 +DA:53,0 +DA:55,1021 +DA:62,0 +DA:69,0 +DA:72,0 +DA:81,2 +DA:84,1 +DA:85,0 +DA:89,1 +DA:90,2 +DA:91,1 +DA:92,0 +DA:95,1 +DA:96,2 +DA:97,2 +DA:98,2 +DA:99,2 +DA:101,0 DA:105,3 DA:106,1 -DA:109,4 -DA:110,4 -DA:111,4 -DA:113,4 -DA:114,0 -DA:117,4 -DA:118,8 -DA:119,20 -DA:120,20 -DA:121,20 -DA:122,20 -DA:124,4 -DA:132,175 -DA:135,175 -DA:136,175 -DA:138,2878 -DA:139,3053 -DA:140,175 -DA:142,175 -DA:143,350 -DA:144,2878 -DA:145,2878 -DA:146,2367 -DA:147,2367 -DA:148,2367 -DA:149,2367 -DA:151,5581 -DA:154,175 -DA:155,175 -DA:156,350 -DA:157,2878 -DA:158,2878 -DA:159,511 -DA:160,0 -DA:163,5581 -DA:164,175 -DA:165,175 -DA:167,0 -DA:170,350 -DA:171,2878 -DA:172,2367 -DA:174,2044 -DA:175,1022 -DA:176,511 -DA:177,5581 -DA:179,175 -DA:180,175 -DA:183,3980 -DA:184,3980 -DA:188,249 -DA:189,249 -DA:190,611 -DA:191,249 -DA:192,0 -DA:194,249 -DA:198,229 -DA:201,229 -LH:104 -LF:112 +DA:115,0 +DA:118,0 +LH:28 +LF:38 end_of_record -SF:src/projectable_evaluator.jl -DA:10,0 -DA:11,0 -DA:22,0 -DA:25,0 -DA:29,0 -DA:30,0 -DA:34,0 -DA:37,0 -DA:45,0 -DA:48,0 -DA:51,145 -DA:52,145 -DA:65,0 -DA:72,0 -DA:73,0 -DA:80,6 -DA:85,6 -DA:86,0 -DA:89,6 -DA:90,6 -DA:91,6 -DA:93,6 -DA:99,6 -DA:100,18 -DA:106,6 -DA:107,6 -DA:109,6 -DA:116,6 -DA:124,0 -DA:125,0 -DA:128,0 +SF:src/projector.jl +DA:8,27958 +DA:9,48324 +DA:10,136792 +DA:11,2 +DA:13,253216 +DA:14,27956 +DA:18,204 +DA:19,204 +DA:22,1 +DA:23,1 +DA:26,5484 +DA:27,5484 +DA:33,2058 +DA:36,17655 +DA:37,17655 +DA:40,17655 +DA:43,0 +DA:45,9382945 +DA:46,0 +DA:47,1099 +DA:49,4 +DA:50,4 +DA:53,4 +DA:54,8 +DA:57,316637 +DA:62,3 +DA:63,6 +DA:64,3 +DA:65,2 +DA:67,2 +DA:69,1 +DA:72,1 +DA:77,2285826 +DA:78,4571652 +DA:79,4584450 +DA:80,2270269 +DA:82,4628362 +DA:83,15557 +DA:91,3733 +DA:92,3733 +DA:93,7466 +DA:94,21781 +DA:95,21737 +DA:97,43562 +DA:98,3733 +DA:101,0 +DA:102,0 +DA:104,2271830 +DA:105,2271830 +DA:108,1150636 +DA:109,1150636 +DA:115,241 +DA:116,241 +DA:117,2262195 +DA:118,2257126 +DA:119,3 +DA:122,2267044 +DA:123,238 DA:129,0 DA:132,0 DA:133,0 -DA:136,327150 -DA:137,327150 -DA:138,327150 -DA:139,327150 -DA:140,327150 -DA:143,327150 -DA:146,2 -DA:147,2 -DA:148,2 -DA:149,2 -DA:150,2 -DA:153,2 -DA:157,327149 -DA:163,327149 -DA:164,327149 -DA:165,0 -DA:167,327149 -DA:168,327149 -DA:172,142109 -DA:173,142109 -DA:177,1 -DA:182,1 -DA:183,0 -DA:185,1 -DA:188,1 -DA:198,0 -DA:201,0 -DA:203,136 -DA:206,136 -DA:207,136 -DA:209,11 -DA:212,11 -DA:213,11 -DA:217,866865 -DA:219,10 -DA:220,10 -DA:226,122823 -DA:227,122823 -DA:230,288955 -DA:236,288955 -DA:237,0 -DA:239,288955 -DA:240,288955 -DA:241,288955 -DA:242,288955 -DA:244,288955 -DA:247,288955 -DA:248,288955 -DA:250,288955 -DA:253,288955 -DA:254,288955 -DA:260,288955 -DA:261,288955 -DA:267,315958 -DA:268,261952 -DA:271,288955 -DA:273,288955 -DA:276,136 -DA:279,136 -DA:280,136 -DA:286,26 -DA:287,13 -DA:288,26 -DA:289,13 -DA:290,0 -DA:292,13 -DA:294,13 -LH:75 -LF:101 -end_of_record -SF:src/projector.jl -DA:4,1058064 -DA:5,2116128 -DA:6,26758939 -DA:7,53517878 -DA:8,53547274 -DA:9,0 -DA:13,26788335 -DA:14,52459814 -DA:15,1058064 -DA:20,932527 -DA:21,932527 -DA:24,1 -DA:25,1 -DA:28,6669080 -DA:29,6919712 -DA:30,125140 -DA:32,6293308 -DA:35,1403076 -DA:36,23842723 -DA:37,1 -DA:39,1 -DA:40,1 -DA:50,20 -DA:51,1 -DA:52,0 -DA:54,2 -DA:55,2 -DA:56,2 -DA:57,2 -DA:58,4 -DA:59,4 -DA:60,8 -DA:61,5 -DA:62,1 -DA:63,4 -DA:64,3 -DA:65,1 -DA:66,1 -DA:68,0 -DA:70,6 -DA:71,4 -DA:72,6 -DA:74,2 -DA:77,716 -DA:78,716 -DA:79,716 -DA:80,716 -DA:81,1432 -DA:82,10751 -DA:83,21502 -DA:84,16666 -DA:85,14303 -DA:87,2363 -DA:89,22581 -DA:90,10751 -DA:91,20786 -DA:93,716 -DA:96,123861 -DA:97,123861 -DA:98,123861 -DA:99,247722 -DA:100,3127908 -DA:101,306806 -DA:102,1 -DA:104,2821102 -DA:105,10534 -DA:106,454 -DA:108,2810568 -DA:111,6131500 -DA:112,123406 -DA:115,1 -DA:117,122826 -DA:119,201 -DA:120,201 -DA:121,402 -DA:122,2116 -DA:123,133 -DA:124,71 -DA:127,3960 -DA:128,130 -DA:131,1207288 -DA:132,18366438 -DA:133,1207288 -DA:134,31462008 -DA:135,2 -DA:137,30267994 -DA:138,1207286 -DA:141,605375 -DA:142,605375 -DA:146,819 -DA:149,819 -DA:150,819 -DA:152,819 -DA:154,9297 -DA:156,1605 -DA:159,819 -DA:162,10980962 -DA:163,21961924 -DA:164,9833175 -DA:165,2295574 -DA:166,1147787 -DA:168,0 -DA:181,125157 -DA:182,125157 -DA:183,125157 -DA:184,125157 -DA:185,125157 -DA:186,3129914 -DA:187,312205 -DA:189,2817709 -DA:190,2817709 -DA:192,6134671 -DA:193,125157 -DA:196,125155 -DA:200,38203 -DA:201,38212 -DA:207,38203 -LH:113 -LF:117 -end_of_record -SF:src/projtensortrain.jl -DA:15,1869 -DA:16,1869 -DA:17,1869 -DA:18,21071 -DA:20,40273 -DA:21,1869 -DA:24,1869 -DA:25,0 -DA:28,1869 -DA:33,431 -DA:34,431 -DA:35,431 -DA:36,431 -DA:37,0 -DA:39,431 -DA:47,2512 -DA:50,2512 -DA:51,2512 -DA:60,28331 -DA:63,56662 -DA:64,25325 -DA:66,3006 -DA:70,3006 -DA:71,3006 -DA:80,1 -DA:81,1 -DA:84,325 -DA:85,325 -DA:88,105 -DA:91,105 -DA:94,16439 -DA:95,16439 -DA:98,596 -DA:101,596 -DA:103,596 -DA:106,596 -DA:109,38201 -DA:111,0 -DA:112,0 -DA:115,0 -DA:116,0 -DA:119,0 -DA:120,0 -DA:121,0 -DA:124,38201 -DA:130,38201 -DA:131,0 -DA:134,38201 -DA:135,38201 -DA:136,38201 -DA:137,38201 -DA:138,38201 -DA:141,39974 -DA:142,38201 -DA:143,38201 -DA:146,38201 -DA:149,1295 -DA:150,2590 -DA:151,1295 -DA:152,1295 -DA:153,1295 -DA:156,1330 -DA:163,665 -DA:166,665 -DA:169,212 -DA:173,453 -DA:174,906 -DA:177,453 -DA:178,453 -DA:179,18 -DA:180,18 -DA:183,453 -DA:188,471 -DA:189,471 -DA:191,471 -DA:192,7492 -DA:194,2705 -DA:196,1041 -DA:197,1041 -DA:198,1626 -DA:199,331 -DA:201,1295 -DA:202,1626 -DA:203,3746 -DA:206,238 -DA:209,119 -DA:210,119 -DA:211,119 -DA:214,16 -DA:215,16 -DA:218,202 -DA:221,101 -DA:229,89 -DA:230,89 -DA:233,308 -DA:242,154 -DA:243,154 -DA:244,154 -DA:245,154 -DA:246,154 -DA:264,91 -DA:265,91 -DA:266,534 -DA:267,91 -DA:268,1068 -DA:269,298 -DA:270,298 -DA:271,472 -DA:272,236 -DA:275,236 -DA:277,0 -DA:279,236 -DA:280,977 -DA:282,91 -DA:283,91 -DA:284,263 -DA:285,172 -DA:286,172 -DA:288,91 -DA:290,91 -DA:293,172 -DA:294,172 -DA:295,172 -DA:296,806 -DA:297,634 -DA:298,634 -DA:299,170 -DA:300,170 -DA:301,464 -DA:302,228 -DA:303,228 -DA:304,241 -DA:305,236 -DA:306,236 -DA:307,236 -DA:308,472 -DA:309,472 -DA:310,472 -DA:312,0 -DA:314,634 -DA:315,172 -DA:318,726 -DA:319,726 -DA:320,726 -DA:323,2112 -DA:324,2112 -DA:327,236 -DA:328,236 -DA:335,252 -DA:338,126 -DA:339,126 -DA:340,126 -DA:341,1232 -DA:344,616 -DA:345,72 -DA:346,24 -DA:347,24 -DA:348,24 -DA:350,592 -DA:351,592 -DA:352,592 -DA:354,1106 -DA:356,126 -DA:357,126 -DA:358,364 -DA:359,602 -DA:361,126 -DA:362,0 -DA:364,126 -DA:366,126 -DA:369,4 -DA:370,4 -DA:371,4 -DA:372,4 -LH:161 -LF:174 +DA:134,0 +DA:135,0 +DA:138,0 +DA:139,0 +DA:142,0 +DA:143,0 +DA:149,8 +DA:150,4 +DA:151,2 +DA:153,4 +DA:155,4 +DA:156,2 +DA:158,6 +DA:160,8 +DA:161,4 +DA:162,0 +DA:164,4 +DA:165,2 +DA:167,2 +DA:171,4 +DA:173,4 +LH:68 +LF:82 end_of_record -SF:src/tree.jl -DA:4,116 -DA:9,483 -DA:12,6 -DA:13,6 -DA:14,6 -DA:17,110 -DA:18,110 -DA:19,110 -DA:23,367 -DA:24,367 -DA:26,367 -DA:29,367 -DA:30,367 -DA:31,651 -DA:33,133 -DA:35,133 -DA:36,96 -DA:39,133 -DA:40,183 -DA:42,367 -DA:48,116 -DA:49,116 -DA:50,116 +SF:src/subdomainmps.jl +DA:8,17446 +DA:9,17446 +DA:12,17446 +DA:13,17446 +DA:17,1884 +DA:19,20799 +DA:20,3353 +DA:22,0 +DA:23,2047 +DA:25,17446 +DA:26,17446 +DA:27,17446 +DA:28,28007 +DA:29,77805 +DA:30,5355 +DA:32,145049 +DA:33,17446 +DA:36,5481 +DA:37,5481 +DA:41,868 +DA:43,300645 +DA:44,1407521 +DA:48,300645 +DA:49,300645 +DA:50,1407521 DA:51,0 -DA:53,7380 -DA:54,116 -DA:55,0 -DA:56,116 -DA:57,0 -DA:59,116 -DA:66,1 -DA:67,1 -DA:68,1 +DA:53,300645 +DA:55,300645 +DA:58,11951 +DA:59,11951 +DA:60,11951 +DA:61,128 +DA:64,11823 DA:69,0 -DA:71,1 +DA:72,0 DA:75,0 -DA:76,0 -DA:77,0 -DA:80,0 -DA:81,0 -DA:82,0 -DA:86,132 -DA:87,132 -DA:89,132 -DA:90,248 -DA:91,35 -DA:92,35 -DA:93,34 -DA:95,1 -DA:97,53 -DA:98,131 -DA:103,40 -DA:106,211 -DA:107,211 -DA:108,171 -DA:109,171 -DA:110,171 -DA:113,171 -DA:114,302 -DA:118,6 -DA:119,6 -DA:120,12 -DA:121,28 -DA:122,50 -DA:123,6 -DA:126,40 -DA:127,40 -DA:130,1 -DA:131,1 -LH:59 -LF:69 +DA:78,0 +DA:81,3411 +DA:82,3411 +DA:85,11 +DA:88,11 +DA:91,142 +DA:94,142 +DA:97,179556 +DA:99,179556 +DA:102,17446 +DA:103,34892 +DA:106,114 +DA:107,114 +DA:108,114 +DA:112,0 +DA:113,0 +DA:116,10 +DA:117,4 +DA:123,0 +DA:124,0 +DA:125,0 +DA:127,0 +DA:130,0 +DA:136,6 +DA:137,3 +DA:140,34474 +DA:141,34474 +DA:144,6456 +DA:150,2088 +DA:151,4176 +DA:153,2088 +DA:154,2088 +DA:155,2088 +DA:156,2088 +DA:159,6614 +DA:160,3307 +DA:161,1219 +DA:162,2088 +DA:163,0 +DA:164,0 +DA:168,0 +DA:169,2088 +DA:170,4368 +DA:171,2280 +DA:175,2088 +DA:176,4368 +DA:177,2088 +DA:179,0 +DA:183,2438 +DA:186,1219 +DA:189,6614 +DA:192,3307 +DA:198,16 +DA:199,16 +DA:202,77 +DA:203,77 +DA:206,0 +DA:207,0 +DA:210,2974 +DA:211,1487 +DA:214,93 +DA:215,93 +DA:216,0 +DA:218,93 +DA:219,93 +DA:222,93 +DA:223,93 +DA:226,188 +DA:229,94 +DA:230,94 +DA:231,70 +DA:232,70 +DA:233,70 +DA:235,94 +DA:236,94 +DA:237,70 +DA:238,32 +DA:239,32 +DA:240,8 +DA:241,8 +DA:244,70 +DA:246,94 +DA:249,0 +DA:250,0 +DA:253,188 +DA:256,94 +DA:259,1 +DA:260,1 +DA:261,1 +DA:263,1 +DA:267,1 +DA:268,1 +DA:269,3 +DA:270,0 +DA:272,3 +DA:274,1 +DA:277,42 +DA:280,42 +DA:281,42 +DA:282,371 +DA:283,48 +DA:284,48 +DA:285,48 +DA:287,48 +DA:289,48 +DA:290,700 +DA:292,42 +DA:294,83 +DA:295,201 +DA:296,201 +DA:297,361 +DA:298,42 +DA:301,1 +DA:302,1 +DA:303,1 +DA:306,0 +DA:307,0 +LH:128 +LF:153 end_of_record SF:src/util.jl -DA:2,2 -DA:3,2 -DA:4,1 -DA:6,1 -DA:7,1 -DA:10,3548 -DA:11,3548 -DA:14,22058823 -DA:15,22058823 -DA:16,22058823 -DA:17,1 -DA:19,22058822 -DA:23,1585682 -DA:24,1585682 -DA:27,14396011 -DA:28,28792022 -DA:29,6658366 -DA:31,7737645 -DA:32,15475290 -DA:33,7737645 -DA:36,441899 -DA:37,441899 -DA:40,1 -DA:41,3 -DA:44,146 -DA:45,146 -DA:46,146 -DA:47,735 -DA:48,735 -DA:49,745 -DA:50,14 -DA:52,721 -DA:53,1324 -DA:54,146 -DA:57,14538 -DA:59,727 -DA:65,727 -DA:67,727 -DA:69,727 -DA:71,727 -DA:72,727 -DA:74,727 -DA:75,727 -DA:77,1454 -DA:82,1454 -DA:88,727 -DA:94,1 -DA:95,1 -DA:96,1 -DA:97,1 -LH:50 -LF:50 +DA:2,0 +DA:3,0 +DA:4,0 +DA:6,0 +DA:7,0 +DA:10,0 +DA:11,0 +DA:14,0 +DA:15,0 +DA:18,0 +DA:20,0 +DA:26,0 +DA:28,0 +DA:30,0 +DA:32,0 +DA:33,0 +DA:35,0 +DA:36,0 +DA:38,0 +DA:43,0 +DA:49,0 +DA:55,0 +DA:56,0 +DA:57,0 +DA:58,0 +DA:61,146 +DA:62,73 +DA:63,73 +DA:64,73 +DA:65,73 +DA:66,73 +DA:67,146 +DA:68,219 +DA:69,73 +DA:74,114 +DA:75,114 +DA:77,114 +DA:80,114 +DA:81,228 +DA:82,114 +DA:83,114 +DA:84,1452 +DA:85,2258 +DA:86,332 +DA:88,3852 +DA:89,1926 +DA:90,1926 +DA:91,1926 +DA:92,1926 +DA:93,1926 +DA:94,1926 +DA:95,1926 +DA:96,2258 +DA:98,1452 +DA:99,1338 +DA:101,1566 +DA:103,1452 +DA:104,2790 +DA:105,114 +DA:106,114 +DA:110,11 +DA:121,10 +DA:124,5 +DA:127,5 +DA:128,5 +DA:129,5 +DA:130,0 +DA:131,0 +DA:133,0 +DA:138,5 +DA:139,5 +DA:140,41 +DA:141,41 +DA:142,41 +DA:143,5 +DA:144,36 +DA:145,0 +DA:147,36 +DA:148,36 +DA:149,5 +DA:152,10 +DA:155,5 +DA:156,5 +DA:157,5 +DA:160,2 +DA:163,1 +DA:165,1 +DA:166,1 +DA:167,6 +DA:168,9 +DA:169,12 +DA:170,11 +DA:172,1 +DA:174,1 +DA:175,10 +DA:176,1 +DA:177,0 +DA:178,0 +DA:181,0 +DA:185,1 +DA:186,1 +DA:187,4 +DA:188,40 +DA:189,4 +DA:190,1 +DA:191,3 +DA:192,0 +DA:195,3 +DA:196,3 +DA:197,1 +DA:200,0 +DA:203,0 +DA:204,0 +DA:205,0 +DA:210,96 +DA:213,48 +DA:214,48 +DA:220,1 +DA:221,1 +DA:222,1 +DA:224,1 +DA:226,2 +DA:227,6 +DA:228,3 +DA:229,5 +DA:231,1 +DA:234,48 +DA:235,48 +DA:236,48 +DA:237,48 +DA:238,48 +DA:239,48 +DA:240,96 +DA:241,144 +DA:242,48 +DA:248,0 +DA:249,0 +DA:250,0 +DA:251,0 +DA:252,0 +DA:253,0 +DA:256,0 +DA:257,0 +DA:258,0 +DA:259,0 +DA:260,0 +DA:261,0 +DA:264,2 +DA:269,2 +DA:270,4 +DA:271,20 +DA:272,38 +DA:273,2 +DA:276,20 +DA:279,20 +DA:280,260 +DA:281,300 +DA:282,40 +DA:283,0 +DA:285,20 +DA:286,0 +DA:288,20 +DA:289,20 +DA:290,20 +DA:291,20 +LH:114 +LF:165 end_of_record diff --git a/src/contract.jl b/src/contract.jl index 40daffa..5627113 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -2,10 +2,14 @@ _alg_map = Dict( ITensors.Algorithm(alg) => alg for alg in ["directsum", "densitymatrix", "fit", "naive"] ) - +""" +Contraction of two SubDomainMPSs. +Only if the shared projected indices overlap the contraction is non-vanishing. +""" function contract( M1::SubDomainMPS, M2::SubDomainMPS; alg, kwargs... )::Union{SubDomainMPS,Nothing} + # If the SubDomainMPS don't overlap they cannot be contracted. if !hasoverlap(M1.projector, M2.projector) return nothing end @@ -24,6 +28,10 @@ function _projector_after_contract(M1::SubDomainMPS, M2::SubDomainMPS) sites2 = _allsites(M2) external_sites = setdiff(union(sites1, sites2), intersect(sites1, sites2)) + # If the SubDomainMPS don't overlap they cannot be contracted -> no final projector + if !hasoverlap(M1.projector, M2.projector) + return nothing, external_sites + end proj = deepcopy(M1.projector.data) empty!(proj) @@ -40,6 +48,7 @@ function _projector_after_contract(M1::SubDomainMPS, M2::SubDomainMPS) return Projector(proj), external_sites end +# Check for newly projected sites to be only external sites. function _is_externalsites_compatible_with_projector(external_sites, projector) for s in keys(projector) if !(s ∈ external_sites) @@ -59,7 +68,6 @@ function projcontract( alg="zipup", cutoff=default_cutoff(), maxdim=default_maxdim(), - verbosity=0, kwargs..., )::Union{Nothing,SubDomainMPS} # Project M1 and M2 to `proj` before contracting @@ -75,16 +83,13 @@ function projcontract( error("The projector contains projection onto a site that is not an external site.") end - # t1 = time_ns() r = contract(M1, M2; alg, cutoff, maxdim, kwargs...) - # t2 = time_ns() - #println("contract: $((t2 - t1)*1e-9) s") return r end """ -Project two SubDomainMPS objects to `proj` before contracting them. -The results are summed. +Project SubDomainMPS vectors to `proj` before computing all possible pairwise contractions of the elements. +The results are summed or patch-summed. """ function projcontract( M1::AbstractVector{SubDomainMPS}, @@ -99,44 +104,10 @@ function projcontract( )::Union{Nothing,Vector{SubDomainMPS}} results = SubDomainMPS[] - # Precollect the pairs for threading - pairinfo = vec([(m1, m2, maxlinkdim(m1) * maxlinkdim(m2)) for m1 in M1, m2 in M2]) - # Heavy contraction first - sort!(pairinfo; by=x -> x[3], rev=true) - # Lock for threaded computation - local_lock = ReentrantLock() - - if Threads.nthreads() > 1 - nT = nthreads() - chunked_pairs = [Vector{Tuple{SubDomainMPS,SubDomainMPS}}() for _ in 1:nT] - # Equally divide expensive computations btw threads - for (i, (m1, m2, _)) in enumerate(pairinfo) - t = ((i - 1) % nT) + 1 - push!(chunked_pairs[t], (m1, m2)) - end - - @threads for t in 1:nT - local_buffer = SubDomainMPS[] - - for (m1, m2) in chunked_pairs[t] - r = projcontract(m1, m2, proj; alg, cutoff, maxdim, kwargs...) - - if r !== nothing - push!(local_buffer, r) # Thread-local accumulation - end - end - - # Lock is held only briefly to merge partial results - lock(local_lock) do - append!(results, local_buffer) - end - end - else - for (m1, m2, _) in pairinfo - r = projcontract(m1, m2, proj; alg, cutoff, maxdim, kwargs...) - if r !== nothing - push!(results, r) - end + for m1 in M1, m2 in M2 + r = projcontract(m1, m2, proj; alg, cutoff, maxdim, kwargs...) + if r !== nothing + push!(results, r) end end @@ -153,14 +124,12 @@ function projcontract( else [_add(results...; alg=alg_sum, cutoff, maxdim, kwargs...)] end - #T3 = time_ns() - #println("mul: $((T2 - T1)*1e-9) s") - #println("add: $((T3 - T2)*1e-9) s") + return res end """ -Contract two Blocked MPS objects. +Contract two PartitionedMPSs MPS objects. At each site, the objects must share at least one site index. """ @@ -175,9 +144,7 @@ function contract( kwargs..., )::Union{PartitionedMPS} M = PartitionedMPS() - if parallel == :distributed_thread - return parallel_contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) - elseif parallel == :distributed + if parallel == :distributed return distribute_contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) elseif parallel == :serial return contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) @@ -186,71 +153,42 @@ function contract( end end -""" -Contract two PartitionedMPS objects. - -Existing blocks `M` in the resulting PartitionedMPS will be overwritten if `overwrite=true`. -""" -function contract!( - M::PartitionedMPS, - M1::PartitionedMPS, - M2::PartitionedMPS; - alg="zipup", - cutoff=default_cutoff(), - maxdim=default_maxdim(), - patchorder=Index[], - overwrite=true, - kwargs..., -)::Union{PartitionedMPS} - blocks_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() - - for m1 in values(M1), m2 in values(M2) - if hasoverlap(m1.projector, m2.projector) - block = add_entry!(blocks_to_sets, _projector_after_contract(m1, m2)[1]) - if haskey(blocks_to_sets, block) - set1, set2 = blocks_to_sets[block] - push!(set1, m1) - push!(set2, m2) +function add_entry!( + dict::Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}, proj::Projector +) + # Iterate over a copy of keys to avoid modifying the dict while looping. + for existing in collect(keys(dict)) + if hasoverlap(existing, proj) + fused_proj = existing | proj + # Save the current value for the overlapping key. + val = dict[existing] + # Remove the old key (this deletes its associated value). + delete!(dict, existing) + # Recursively update with the fused projector. + new_key = add_entry!(dict, fused_proj) + # If new_key is already present, merge the values; otherwise, insert the saved value. + if haskey(dict, new_key) + old_val = dict[new_key] + dict[new_key] = (union(old_val[1], val[1]), union(old_val[2], val[2])) else - blocks_to_sets[block] = (Set([m1]), Set([m2])) + dict[new_key] = val end + return new_key end end - - for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) - if b1 != b2 && hasoverlap(b1, b2) - error("After contraction, projectors must not overlap.") - end - end - - # Builds tasks to parallelise - tasks = Vector{Tuple{Projector,Vector{SubDomainMPS},Vector{SubDomainMPS}}}() - for (proj, (set1, set2)) in blocks_to_sets - if haskey(M.data, proj) && !overwrite - continue - end - push!(tasks, (proj, collect(set1), collect(set2))) - end - - function process_task(task) - proj, M1_subs, M2_subs = task - return projcontract( - M1_subs, M2_subs, proj; alg, cutoff, maxdim, patchorder, kwargs... - ) - end - - results = map(process_task, tasks) - - for res in results - if res !== nothing - append!(M, res) - end + # If no overlapping key is found, then ensure proj is in the dictionary. + if !haskey(dict, proj) + dict[proj] = (Set{SubDomainMPS}(), Set{SubDomainMPS}()) end - - return M + return proj end -function parallel_contract!( +""" +Contract two PartitionedMPS objects. + +Existing patches `M` in the resulting PartitionedMPS will be overwritten if `overwrite=true`. +""" +function contract!( M::PartitionedMPS, M1::PartitionedMPS, M2::PartitionedMPS; @@ -261,22 +199,22 @@ function parallel_contract!( overwrite=true, kwargs..., )::Union{PartitionedMPS} - blocks_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + patches_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() for m1 in values(M1), m2 in values(M2) if hasoverlap(m1.projector, m2.projector) - block = add_entry!(blocks_to_sets, _projector_after_contract(m1, m2)[1]) - if haskey(blocks_to_sets, block) - set1, set2 = blocks_to_sets[block] + patch = add_entry!(patches_to_sets, _projector_after_contract(m1, m2)[1]) + if haskey(patches_to_sets, patch) + set1, set2 = patches_to_sets[patch] push!(set1, m1) push!(set2, m2) else - blocks_to_sets[block] = (Set([m1]), Set([m2])) + patches_to_sets[patch] = (Set([m1]), Set([m2])) end end end - for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) + for b1 in keys(patches_to_sets), b2 in keys(patches_to_sets) if b1 != b2 && hasoverlap(b1, b2) error("After contraction, projectors must not overlap.") end @@ -284,7 +222,7 @@ function parallel_contract!( # Builds tasks to parallelise tasks = Vector{Tuple{Projector,Vector{SubDomainMPS},Vector{SubDomainMPS}}}() - for (proj, (set1, set2)) in blocks_to_sets + for (proj, (set1, set2)) in patches_to_sets if haskey(M.data, proj) && !overwrite continue end @@ -298,9 +236,9 @@ function parallel_contract!( ) end - results_parallel = pmap(process_task, tasks) + results = map(process_task, tasks) - for res in results_parallel + for res in results if res !== nothing append!(M, res) end @@ -309,36 +247,6 @@ function parallel_contract!( return M end -function add_entry!( - dict::Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}, proj::Projector -) - # Iterate over a copy of keys to avoid modifying the dict while looping. - for existing in collect(keys(dict)) - if hasoverlap(existing, proj) - fused_proj = existing | proj - # Save the current value for the overlapping key. - val = dict[existing] - # Remove the old key (this deletes its associated value). - delete!(dict, existing) - # Recursively update with the fused projector. - new_key = add_entry!(dict, fused_proj) - # If new_key is already present, merge the values; otherwise, insert the saved value. - if haskey(dict, new_key) - old_val = dict[new_key] - dict[new_key] = (union(old_val[1], val[1]), union(old_val[2], val[2])) - else - dict[new_key] = val - end - return new_key - end - end - # If no overlapping key is found, then ensure proj is in the dictionary. - if !haskey(dict, proj) - dict[proj] = (Set{SubDomainMPS}(), Set{SubDomainMPS}()) - end - return proj -end - function distribute_contract!( M::PartitionedMPS, M1::PartitionedMPS, @@ -351,29 +259,29 @@ function distribute_contract!( overwrite=true, kwargs..., )::Union{PartitionedMPS} - blocks_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() + patches_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() for m1 in values(M1), m2 in values(M2) if hasoverlap(m1.projector, m2.projector) - block = add_entry!(blocks_to_sets, _projector_after_contract(m1, m2)[1]) - if haskey(blocks_to_sets, block) - set1, set2 = blocks_to_sets[block] + patch = add_entry!(patches_to_sets, _projector_after_contract(m1, m2)[1]) + if haskey(patches_to_sets, patch) + set1, set2 = patches_to_sets[patch] push!(set1, m1) push!(set2, m2) else - blocks_to_sets[block] = (Set([m1]), Set([m2])) + patches_to_sets[patch] = (Set([m1]), Set([m2])) end end end - for b1 in keys(blocks_to_sets), b2 in keys(blocks_to_sets) + for b1 in keys(patches_to_sets), b2 in keys(patches_to_sets) if b1 != b2 && hasoverlap(b1, b2) error("After contraction, projectors must not overlap.") end end tasks = Vector{Tuple{Projector,SubDomainMPS,SubDomainMPS}}() - for (proj, (set1, set2)) in blocks_to_sets + for (proj, (set1, set2)) in patches_to_sets for subdmps1 in set1, subdmps2 in set2 if haskey(M.data, proj) && !overwrite continue @@ -392,18 +300,18 @@ function distribute_contract!( results = pmap(task -> process_task(task; alg, cutoff, maxdim, kwargs...), tasks) valid_results = filter(x -> x[2] !== nothing, results) - block_group = Dict{Projector,Vector{SubDomainMPS}}() + patch_group = Dict{Projector,Vector{SubDomainMPS}}() for (b, subdmps) in valid_results - if haskey(block_group, b) - push!(block_group[b], subdmps) + if haskey(patch_group, b) + push!(patch_group[b], subdmps) else - block_group[b] = [subdmps] + patch_group[b] = [subdmps] end end - block_group_array = collect(block_group) + patch_group_array = collect(patch_group) - function sum_blocks(group; patchorder, alg_sum, cutoff, maxdim, kwargs...) + function sum_patches(group; patchorder, alg_sum, cutoff, maxdim, kwargs...) b, subdmps_list = group if length(subdmps_list) == 1 return [subdmps_list[1]] @@ -418,7 +326,7 @@ function distribute_contract!( end summed_patches = pmap( - group -> sum_blocks( + group -> sum_patches( group; patchorder=patchorder, alg_sum=alg_sum, @@ -426,7 +334,7 @@ function distribute_contract!( maxdim=maxdim, kwargs..., ), - block_group_array, + patch_group_array, ) for res in summed_patches diff --git a/test/bak/conversion_tests.jl b/test/bak/conversion_tests.jl index 9f8ae02..1a095cf 100644 --- a/test/bak/conversion_tests.jl +++ b/test/bak/conversion_tests.jl @@ -5,10 +5,10 @@ using ITensors import TensorCrossInterpolation as TCI import TCIAlgorithms as TCIA using TCIITensorConversion -import PartitionedMPSs: PartitionedMPSs, SubDomainMPS, PartitionedMPS -#import FastMPOContractions as FMPOC -#import Quantics: asMPO -#using Quantics: Quantics +import PartitionedMPSs: PartitionedMPSs + +conversion_file = normpath(joinpath(dirname(pathof(PartitionedMPSs)), "bak/conversion.jl")) +include(conversion_file) @testset "conversion.jl" begin @testset "TCIA.ProjTensorTrain => SubDomainMPS" begin diff --git a/test/runtests.jl b/test/runtests.jl index 97e7ed9..11f900e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -import PartitionedMPSs as PM +using PartitionedMPSs: PartitionedMPSs using Random using ITensors using ITensorMPS @@ -13,5 +13,4 @@ include("contract_tests.jl") include("patching_tests.jl") include("util_tests.jl") include("automul_tests.jl") - -# include("bak/conversion_tests.jl") +include("bak/conversion_tests.jl") From f7f27e22da2fb3f26c92de5d57757c92b1cdc2ba Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Wed, 16 Apr 2025 22:59:52 +0200 Subject: [PATCH 09/11] Reformat distribute contract --- benchmark/distributecontract.jl | 86 +++++ coverage/lcov.info | 610 ++++++++++++++++---------------- src/contract.jl | 280 +++++++-------- src/partitionedmps.jl | 8 +- src/projector.jl | 5 +- 5 files changed, 521 insertions(+), 468 deletions(-) create mode 100644 benchmark/distributecontract.jl diff --git a/benchmark/distributecontract.jl b/benchmark/distributecontract.jl new file mode 100644 index 0000000..4f9229c --- /dev/null +++ b/benchmark/distributecontract.jl @@ -0,0 +1,86 @@ +using Distributed +using BenchmarkTools +using Random + +nworkers = 2 + +if nworkers > nprocs() - 1 + addprocs(nworkers) +end + +library_dir = normpath(joinpath(dirname(pathof(PartitionedMPSs)))) + +@everywhere begin + using Pkg + Pkg.activate(library_dir) + Pkg.instantiate() + Pkg.precompile() +end + +@everywhere begin + import PartitionedMPSs: + PartitionedMPSs, + contract, + PartitionedMPS, + SubDomainMPS, + Projector, + project, + _add, + projcontract +end + +using ITensors, ITensorMPS + +random_mpo_file = normpath(joinpath(dirname(pathof(PartitionedMPSs)), "../test/_util.jl")) + +Random.seed!(1234) +R = 5 +d = 2 +L = 5 + +sites_x = [Index(d, "Qubit,x=$x") for x in 1:R] +sites_y = [Index(d, "Qubit,y=$y") for y in 1:R] +sites_s = [Index(d, "Qubit,s=$s") for s in 1:R] + +sites_xs = collect(Iterators.flatten(zip(sites_x, sites_s))) +sites_sy = collect(Iterators.flatten(zip(sites_s, sites_y))) + +Ψ_l = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_xs]; linkdims=L)) +Ψ_r = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_sy]; linkdims=L)) + +proj_lev_l = 4 +proj_l = vec([ + Dict(zip(sites_xs, combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_l)...) +]) + +proj_lev_r = 6 +proj_r = vec([ + Dict(zip(sites_sy, combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_r)...) +]) + +partΨ_l = PartitionedMPS(project.(Ref(Ψ_l), proj_l)) +partΨ_r = PartitionedMPS(project.(Ref(Ψ_r), proj_r)) + +# ------------------------------------------------------------------ +# Warm‑up (compile both code paths) -------------------------------- +# ------------------------------------------------------------------ +println("warming up …") +serial_warm = contract(partΨ_l, partΨ_r; parallel=:serial) +dist_warm = contract(partΨ_l, partΨ_r; parallel=:distributed) +@assert MPS(serial_warm) ≈ MPS(dist_warm) # sanity check + +# ------------------------------------------------------------------ +# Benchmark -------------------------------------------------------- +# ------------------------------------------------------------------ +println("\nbenchmarking …") +serial_time = @belapsed contract($partΨ_l, $partΨ_r; parallel=:serial) +dist_time = @belapsed contract($partΨ_l, $partΨ_r; parallel=:distributed) + +println("\n---------------- results ----------------") +println("workers : $(nprocs() - 1)") +println("serial time : $(round(serial_time; digits = 4)) s") +println("distributed time : $(round(dist_time; digits = 4)) s") +println("speed‑up : $(round(serial_time / dist_time; digits = 2))×") +println("-----------------------------------------") diff --git a/coverage/lcov.info b/coverage/lcov.info index 62bf682..2eaeeca 100644 --- a/coverage/lcov.info +++ b/coverage/lcov.info @@ -78,10 +78,10 @@ DA:101,10 DA:102,10 DA:103,0 DA:105,10 -DA:106,5 -DA:107,5 -DA:109,5 -DA:110,5 +DA:106,7 +DA:107,7 +DA:109,3 +DA:110,3 DA:113,10 DA:114,1 LH:40 @@ -112,19 +112,19 @@ LH:20 LF:20 end_of_record SF:src/contract.jl -DA:9,2198 -DA:13,1099 -DA:14,800 +DA:9,630 +DA:13,315 +DA:14,16 DA:16,598 DA:18,299 DA:19,299 DA:22,299 -DA:26,1676 -DA:27,1676 -DA:28,1676 -DA:30,1676 -DA:32,1676 -DA:33,800 +DA:26,1700 +DA:27,1700 +DA:28,1700 +DA:30,1700 +DA:32,1700 +DA:33,824 DA:36,876 DA:37,14001 DA:39,876 @@ -134,132 +134,115 @@ DA:43,17237 DA:44,2459 DA:46,17237 DA:48,876 -DA:52,1097 -DA:53,2194 -DA:54,40906 +DA:52,313 +DA:53,626 +DA:54,10762 DA:55,0 -DA:57,10582 -DA:58,1097 -DA:64,2386 -DA:74,1193 -DA:75,1193 -DA:76,2322 +DA:57,2838 +DA:58,313 +DA:64,818 +DA:74,409 +DA:75,409 +DA:76,754 DA:77,96 -DA:80,1394 -DA:82,1097 +DA:80,610 +DA:82,313 DA:83,0 -DA:86,1097 -DA:87,1097 -DA:94,112 -DA:105,56 -DA:107,56 -DA:108,1192 -DA:109,1192 -DA:110,296 -DA:112,1192 -DA:114,56 +DA:86,313 +DA:87,313 +DA:94,16 +DA:105,8 +DA:107,8 +DA:108,128 +DA:109,128 +DA:110,16 +DA:112,128 +DA:114,8 DA:115,0 -DA:118,56 -DA:119,8 -DA:122,48 -DA:123,4 -DA:125,92 -DA:128,48 -DA:136,8 -DA:146,4 -DA:147,4 -DA:148,0 -DA:149,4 -DA:150,4 -DA:152,0 -DA:156,512 -DA:160,512 -DA:161,9768 -DA:162,232 -DA:164,232 -DA:166,232 -DA:168,232 -DA:170,232 -DA:171,232 -DA:172,464 -DA:174,0 -DA:176,232 -DA:178,9536 -DA:180,280 -DA:181,280 -DA:183,280 -DA:191,8 +DA:118,8 +DA:119,0 +DA:122,8 +DA:123,0 +DA:125,16 +DA:128,8 +DA:133,512 +DA:137,512 +DA:138,9849 +DA:139,232 +DA:141,232 +DA:143,232 +DA:145,232 +DA:147,232 +DA:148,464 +DA:150,0 +DA:152,232 +DA:154,9617 +DA:156,280 +DA:157,280 +DA:159,280 +DA:163,8 +DA:169,4 +DA:172,32 +DA:173,1368 +DA:174,1088 +DA:175,280 +DA:176,280 +DA:177,280 +DA:179,0 +DA:182,1112 +DA:185,52 +DA:186,1120 +DA:187,0 +DA:189,1168 +DA:192,4 +DA:193,8 +DA:194,48 +DA:195,0 +DA:197,48 +DA:198,280 +DA:199,280 +DA:200,92 DA:202,4 -DA:204,32 -DA:205,1088 -DA:206,280 -DA:207,280 -DA:208,280 -DA:209,280 -DA:210,280 -DA:212,0 -DA:215,1112 -DA:217,52 -DA:218,1120 -DA:219,0 -DA:221,1168 -DA:224,4 -DA:225,8 -DA:226,48 -DA:227,0 -DA:229,48 -DA:230,92 -DA:232,52 -DA:233,48 -DA:234,48 -DA:239,4 -DA:241,4 -DA:242,48 -DA:243,48 -DA:245,48 -DA:247,4 -DA:250,0 -DA:262,0 -DA:264,0 -DA:265,0 -DA:266,0 -DA:267,0 -DA:268,0 -DA:269,0 -DA:270,0 -DA:272,0 -DA:275,0 -DA:277,0 -DA:278,0 -DA:279,0 -DA:281,0 -DA:283,0 -DA:284,0 -DA:285,0 -DA:286,0 -DA:287,0 -DA:289,0 -DA:290,0 -DA:291,0 -DA:293,0 -DA:300,0 -DA:301,0 -DA:303,0 -DA:304,0 +DA:210,8 +DA:220,4 +DA:221,4 +DA:229,8 +DA:243,4 +DA:246,564 +DA:247,280 +DA:248,280 +DA:252,4 +DA:253,4 +DA:254,280 +DA:256,0 +DA:257,0 +DA:261,0 +DA:265,284 +DA:270,4 +DA:271,4 +DA:272,280 +DA:273,232 +DA:275,48 +DA:277,280 +DA:280,100 +DA:281,48 +DA:282,8 +DA:284,40 +DA:285,4 +DA:287,76 +DA:289,40 +DA:293,4 +DA:294,8 +DA:295,48 DA:305,0 DA:306,0 -DA:308,0 -DA:310,0 -DA:312,0 -DA:314,0 -DA:328,0 -DA:340,0 -DA:341,0 -DA:342,0 -DA:344,0 -DA:346,0 -LH:97 -LF:146 +DA:320,4 +DA:321,48 +DA:322,96 +DA:324,48 +DA:326,4 +LH:115 +LF:129 end_of_record SF:src/partitionedmps.jl DA:8,172 @@ -272,99 +255,99 @@ DA:15,1881 DA:18,170 DA:22,20 DA:24,24 -DA:26,134 +DA:26,245 DA:28,67 DA:29,134 DA:30,0 -DA:32,134 -DA:33,1141 -DA:34,2215 -DA:35,67 -DA:38,67 -DA:39,67 -DA:46,19 -DA:47,19 -DA:50,19 -DA:56,2 -DA:61,0 -DA:62,0 +DA:34,134 +DA:35,1141 +DA:36,2215 +DA:37,67 +DA:40,67 +DA:41,67 +DA:48,19 +DA:49,19 +DA:52,19 +DA:58,2 DA:63,0 -DA:66,643 -DA:68,1 -DA:69,1 -DA:72,1 -DA:73,2 -DA:79,494 -DA:80,494 -DA:86,69 -DA:87,69 -DA:94,3 -DA:95,3 -DA:98,4 -DA:99,2 -DA:104,0 -DA:105,0 -DA:113,1 -DA:114,1 -DA:123,2 -DA:131,1 -DA:132,2 -DA:133,10 -DA:134,19 -DA:135,1 -DA:138,38 -DA:147,19 -DA:148,19 -DA:151,38 -DA:162,19 -DA:163,19 +DA:64,0 +DA:65,0 +DA:68,643 +DA:70,1 +DA:71,1 +DA:74,1 +DA:75,2 +DA:81,494 +DA:82,494 +DA:88,69 +DA:89,69 +DA:96,3 +DA:97,3 +DA:100,4 +DA:101,2 +DA:106,0 +DA:107,0 +DA:115,1 +DA:116,1 +DA:125,2 +DA:133,1 +DA:134,2 +DA:135,10 +DA:136,19 +DA:137,1 +DA:140,38 +DA:149,19 +DA:150,19 +DA:153,38 DA:164,19 -DA:165,73 -DA:166,0 -DA:168,73 -DA:169,4 -DA:170,4 -DA:173,69 -DA:174,52 -DA:175,17 -DA:176,17 -DA:178,0 -DA:180,73 -DA:181,19 -DA:184,14 -DA:185,14 -DA:188,13 -DA:189,13 -DA:192,0 -DA:193,0 -DA:203,14 -DA:211,7 -DA:212,7 +DA:165,19 +DA:166,19 +DA:167,73 +DA:168,0 +DA:170,73 +DA:171,4 +DA:172,4 +DA:175,69 +DA:176,52 +DA:177,17 +DA:178,17 +DA:180,0 +DA:182,73 +DA:183,19 +DA:186,14 +DA:187,14 +DA:190,13 +DA:191,13 +DA:194,0 +DA:195,0 +DA:205,14 DA:213,7 +DA:214,7 DA:215,7 DA:217,7 -DA:218,19 -DA:222,19 -DA:223,33 -DA:224,5 -DA:227,14 -DA:228,14 -DA:230,7 -DA:234,62 -DA:237,31 -DA:238,1215 -DA:242,0 -DA:245,0 -DA:251,8 -DA:254,4 -DA:259,0 -DA:260,0 -DA:269,2 -DA:270,2 -DA:273,19 -DA:274,19 -DA:277,6 -DA:278,3 +DA:219,7 +DA:220,19 +DA:224,19 +DA:225,33 +DA:226,5 +DA:229,14 +DA:230,14 +DA:232,7 +DA:236,62 +DA:239,31 +DA:240,1215 +DA:244,0 +DA:247,0 +DA:253,8 +DA:256,4 +DA:261,0 +DA:262,0 +DA:271,2 +DA:272,2 +DA:275,19 +DA:276,19 +DA:279,6 +DA:280,3 LH:89 LF:103 end_of_record @@ -411,12 +394,12 @@ LH:28 LF:38 end_of_record SF:src/projector.jl -DA:8,27958 -DA:9,48324 -DA:10,136792 +DA:8,26110 +DA:9,43857 +DA:10,122239 DA:11,2 -DA:13,253216 -DA:14,27956 +DA:13,226729 +DA:14,26108 DA:18,204 DA:19,204 DA:22,1 @@ -424,112 +407,113 @@ DA:23,1 DA:26,5484 DA:27,5484 DA:33,2058 -DA:36,17655 -DA:37,17655 -DA:40,17655 -DA:43,0 -DA:45,9382945 -DA:46,0 -DA:47,1099 -DA:49,4 +DA:34,2058 +DA:37,13044 +DA:38,13044 +DA:41,13044 +DA:44,0 +DA:46,10176154 +DA:47,0 +DA:48,315 DA:50,4 -DA:53,4 -DA:54,8 -DA:57,316637 -DA:62,3 -DA:63,6 -DA:64,3 -DA:65,2 -DA:67,2 -DA:69,1 -DA:72,1 -DA:77,2285826 -DA:78,4571652 -DA:79,4584450 -DA:80,2270269 -DA:82,4628362 -DA:83,15557 -DA:91,3733 +DA:51,4 +DA:54,4 +DA:55,8 +DA:58,138932 +DA:63,3 +DA:64,6 +DA:65,3 +DA:66,2 +DA:68,2 +DA:70,1 +DA:73,1 +DA:78,2282491 +DA:79,4564982 +DA:80,4989705 +DA:81,2268782 +DA:83,5441846 +DA:84,13709 DA:92,3733 -DA:93,7466 -DA:94,21781 -DA:95,21737 -DA:97,43562 -DA:98,3733 -DA:101,0 +DA:93,3733 +DA:94,7466 +DA:95,21522 +DA:96,21478 +DA:98,43044 +DA:99,3733 DA:102,0 -DA:104,2271830 -DA:105,2271830 -DA:108,1150636 -DA:109,1150636 -DA:115,241 +DA:103,0 +DA:105,2270063 +DA:106,2270063 +DA:109,1031468 +DA:110,1031468 DA:116,241 -DA:117,2262195 -DA:118,2257126 -DA:119,3 -DA:122,2267044 -DA:123,238 -DA:129,0 -DA:132,0 +DA:117,241 +DA:118,2262195 +DA:119,2257126 +DA:120,3 +DA:123,2267044 +DA:124,238 +DA:130,0 DA:133,0 DA:134,0 DA:135,0 -DA:138,0 +DA:136,0 DA:139,0 -DA:142,0 +DA:140,0 DA:143,0 -DA:149,8 -DA:150,4 -DA:151,2 -DA:153,4 -DA:155,4 -DA:156,2 -DA:158,6 -DA:160,8 -DA:161,4 -DA:162,0 -DA:164,4 -DA:165,2 -DA:167,2 -DA:171,4 -DA:173,4 -LH:68 -LF:82 +DA:144,0 +DA:150,8 +DA:151,4 +DA:152,2 +DA:154,4 +DA:156,4 +DA:157,2 +DA:159,6 +DA:161,8 +DA:162,4 +DA:163,0 +DA:165,4 +DA:166,2 +DA:168,2 +DA:172,4 +DA:174,4 +LH:69 +LF:83 end_of_record SF:src/subdomainmps.jl -DA:8,17446 -DA:9,17446 -DA:12,17446 -DA:13,17446 +DA:8,15878 +DA:9,15878 +DA:12,15878 +DA:13,15878 DA:17,1884 -DA:19,20799 -DA:20,3353 +DA:19,19279 +DA:20,3401 DA:22,0 DA:23,2047 -DA:25,17446 -DA:26,17446 -DA:27,17446 -DA:28,28007 -DA:29,77805 -DA:30,5355 -DA:32,145049 -DA:33,17446 +DA:25,15878 +DA:26,15878 +DA:27,15878 +DA:28,24357 +DA:29,65671 +DA:30,1483 +DA:32,122863 +DA:33,15878 DA:36,5481 DA:37,5481 DA:41,868 -DA:43,300645 -DA:44,1407521 -DA:48,300645 -DA:49,300645 -DA:50,1407521 +DA:43,269285 +DA:44,1256993 +DA:48,269285 +DA:49,269285 +DA:50,1256993 DA:51,0 -DA:53,300645 -DA:55,300645 -DA:58,11951 -DA:59,11951 -DA:60,11951 +DA:53,269285 +DA:55,269285 +DA:58,10383 +DA:59,10383 +DA:60,10383 DA:61,128 -DA:64,11823 +DA:64,10255 DA:69,0 DA:72,0 DA:75,0 @@ -540,10 +524,10 @@ DA:85,11 DA:88,11 DA:91,142 DA:94,142 -DA:97,179556 -DA:99,179556 -DA:102,17446 -DA:103,34892 +DA:97,163876 +DA:99,163876 +DA:102,15878 +DA:103,31756 DA:106,114 DA:107,114 DA:108,114 diff --git a/src/contract.jl b/src/contract.jl index 5627113..e47f160 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -89,7 +89,7 @@ end """ Project SubDomainMPS vectors to `proj` before computing all possible pairwise contractions of the elements. -The results are summed or patch-summed. +The results are summed or patch-summed if belonging to the same patch. """ function projcontract( M1::AbstractVector{SubDomainMPS}, @@ -128,126 +128,105 @@ function projcontract( return res end -""" -Contract two PartitionedMPSs MPS objects. - -At each site, the objects must share at least one site index. -""" -function contract( - M1::PartitionedMPS, - M2::PartitionedMPS; - alg="zipup", - cutoff=default_cutoff(), - maxdim=default_maxdim(), - patchorder=Index[], - parallel::Symbol=:serial, - kwargs..., -)::Union{PartitionedMPS} - M = PartitionedMPS() - if parallel == :distributed - return distribute_contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) - elseif parallel == :serial - return contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, kwargs...) - else - error("Symbol $(parallel) not recongnized.") - end -end - -function add_entry!( - dict::Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}, proj::Projector +# Function to add a new patch to the result patched contraction. Only if the patch is non-overlapping with any +# of the already present ones it is added, otherwise it is fused. +function add_result_patch!( + dict::Dict{Projector,Vector{Tuple{SubDomainMPS,SubDomainMPS}}}, proj::Projector ) - # Iterate over a copy of keys to avoid modifying the dict while looping. - for existing in collect(keys(dict)) - if hasoverlap(existing, proj) - fused_proj = existing | proj - # Save the current value for the overlapping key. - val = dict[existing] - # Remove the old key (this deletes its associated value). - delete!(dict, existing) + # Iterate over a copy of keys (patches) to avoid modifications while looping. + for existing_proj in collect(keys(dict)) + if hasoverlap(existing_proj, proj) + fused_proj = existing_proj | proj + # Save the subdmpss of the overlapping patch. + subdmpss = dict[existing_proj] + # Remove the old projector (this deletes also its associated subdmpss). + delete!(dict, existing_proj) # Recursively update with the fused projector. - new_key = add_entry!(dict, fused_proj) - # If new_key is already present, merge the values; otherwise, insert the saved value. - if haskey(dict, new_key) - old_val = dict[new_key] - dict[new_key] = (union(old_val[1], val[1]), union(old_val[2], val[2])) + new_proj = add_result_patch!(dict, fused_proj) + # If new_proj is already present, merge the subdmpss; otherwise, insert the saved subdmpss. + if haskey(dict, new_proj) + append!(dict[new_proj], subdmpss) else - dict[new_key] = val + dict[new_proj] = subdmpss end - return new_key + return new_proj end end - # If no overlapping key is found, then ensure proj is in the dictionary. + # If no overlapping proj is found, then ensure proj is in the dictionary (sanity passage). if !haskey(dict, proj) - dict[proj] = (Set{SubDomainMPS}(), Set{SubDomainMPS}()) + dict[proj] = Vector{Tuple{SubDomainMPS,SubDomainMPS}}() end return proj end -""" -Contract two PartitionedMPS objects. - -Existing patches `M` in the resulting PartitionedMPS will be overwritten if `overwrite=true`. -""" -function contract!( - M::PartitionedMPS, +# Preprocessing of the patches to obtain all the contraction tasks from two PartitionedMPSs +function _contraction_tasks( M1::PartitionedMPS, M2::PartitionedMPS; - alg="zipup", - cutoff=default_cutoff(), - maxdim=default_maxdim(), - patchorder=Index[], + M::PartitionedMPS=PartitionedMPS(), overwrite=true, - kwargs..., -)::Union{PartitionedMPS} - patches_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() - +)::Vector{Tuple{Projector,SubDomainMPS,SubDomainMPS}} + final_patches = Dict{Projector,Vector{Tuple{SubDomainMPS,SubDomainMPS}}}() + # Add a new patch only if the two subdmps are compatible (overlapping internal projected + # sites) and the new patch is non-overlapping with all the existing ones. for m1 in values(M1), m2 in values(M2) - if hasoverlap(m1.projector, m2.projector) - patch = add_entry!(patches_to_sets, _projector_after_contract(m1, m2)[1]) - if haskey(patches_to_sets, patch) - set1, set2 = patches_to_sets[patch] - push!(set1, m1) - push!(set2, m2) + tmp_prj = _projector_after_contract(m1, m2)[1] + if tmp_prj !== nothing + patch = add_result_patch!(final_patches, tmp_prj) + if haskey(final_patches, patch) + push!(final_patches[patch], (m1, m2)) else - patches_to_sets[patch] = (Set([m1]), Set([m2])) + final_patches[patch] = (m1, m2) end end end - for b1 in keys(patches_to_sets), b2 in keys(patches_to_sets) - if b1 != b2 && hasoverlap(b1, b2) + # Sanity check + for p1 in keys(final_patches), p2 in keys(final_patches) + if p1 != p2 && hasoverlap(p1, p2) error("After contraction, projectors must not overlap.") end end - # Builds tasks to parallelise - tasks = Vector{Tuple{Projector,Vector{SubDomainMPS},Vector{SubDomainMPS}}}() - for (proj, (set1, set2)) in patches_to_sets + # Flatten the result to create contraction tasks + tasks = Vector{Tuple{Projector,SubDomainMPS,SubDomainMPS}}() + for (proj, submps_pairs) in final_patches if haskey(M.data, proj) && !overwrite continue end - push!(tasks, (proj, collect(set1), collect(set2))) - end - - function process_task(task) - proj, M1_subs, M2_subs = task - return projcontract( - M1_subs, M2_subs, proj; alg, cutoff, maxdim, patchorder, kwargs... - ) + for (subdmps1, subdmps2) in submps_pairs + push!(tasks, (proj, subdmps1, subdmps2)) + end end - results = map(process_task, tasks) + return tasks +end - for res in results - if res !== nothing - append!(M, res) - end - end +""" +Contract two PartitionedMPSs MPS objects. - return M +At each site, the objects must share at least one site index. +""" +function contract( + M1::PartitionedMPS, + M2::PartitionedMPS; + alg="zipup", + cutoff=default_cutoff(), + maxdim=default_maxdim(), + patchorder=Index[], + parallel::Symbol=:serial, + kwargs..., +)::Union{PartitionedMPS} + M = PartitionedMPS() + return contract!(M, M1, M2; alg, cutoff, maxdim, patchorder, parallel, kwargs...) end -function distribute_contract!( +""" +Contract two PartitionedMPS objects. + +Existing patches `M` in the resulting PartitionedMPS will be overwritten if `overwrite=true`. +""" +function contract!( M::PartitionedMPS, M1::PartitionedMPS, M2::PartitionedMPS; @@ -256,90 +235,91 @@ function distribute_contract!( cutoff=default_cutoff(), maxdim=default_maxdim(), patchorder=Index[], + parallel::Symbol=:serial, overwrite=true, kwargs..., -)::Union{PartitionedMPS} - patches_to_sets = Dict{Projector,Tuple{Set{SubDomainMPS},Set{SubDomainMPS}}}() - - for m1 in values(M1), m2 in values(M2) - if hasoverlap(m1.projector, m2.projector) - patch = add_entry!(patches_to_sets, _projector_after_contract(m1, m2)[1]) - if haskey(patches_to_sets, patch) - set1, set2 = patches_to_sets[patch] - push!(set1, m1) - push!(set2, m2) - else - patches_to_sets[patch] = (Set([m1]), Set([m2])) - end - end - end - - for b1 in keys(patches_to_sets), b2 in keys(patches_to_sets) - if b1 != b2 && hasoverlap(b1, b2) - error("After contraction, projectors must not overlap.") - end - end +)::PartitionedMPS + # Builds contraction tasks + tasks = _contraction_tasks(M1, M2; M=M, overwrite=overwrite) - tasks = Vector{Tuple{Projector,SubDomainMPS,SubDomainMPS}}() - for (proj, (set1, set2)) in patches_to_sets - for subdmps1 in set1, subdmps2 in set2 - if haskey(M.data, proj) && !overwrite - continue - end - push!(tasks, (proj, subdmps1, subdmps2)) - end + # Helper contraction function + function contract_task(task; alg, cutoff, maxdim, kwargs...) + proj, M1_subs, M2_subs = task + return projcontract(M1_subs, M2_subs, proj; alg, cutoff, maxdim, kwargs...) end - function process_task(task_tuple; alg, cutoff, maxdim, kwargs...) - # Unpack the tuple - proj, subdmps1, subdmps2 = task_tuple - res = projcontract(subdmps1, subdmps2, proj; alg, cutoff, maxdim, kwargs...) - return (proj, res) + # Serial or distributed contraction + if parallel == :serial + contr_results = map( + task -> contract_task(task; alg, cutoff, maxdim, kwargs...), tasks + ) + elseif parallel == :distributed + contr_results = pmap( + task -> contract_task(task; alg, cutoff, maxdim, kwargs...), tasks + ) + else + error("Symbol $(parallel) not recongnized.") end - results = pmap(task -> process_task(task; alg, cutoff, maxdim, kwargs...), tasks) - valid_results = filter(x -> x[2] !== nothing, results) + # Sanity check + all(r -> r !== nothing, contr_results) || + error("Some contraction returned `nothing`. Faulty preprocessing of patches...") + ## Resum SubDomainMPSs projected on the same final patch + # Group together patches to resum patch_group = Dict{Projector,Vector{SubDomainMPS}}() - for (b, subdmps) in valid_results - if haskey(patch_group, b) - push!(patch_group[b], subdmps) + for subdmps in contr_results + if haskey(patch_group, subdmps.projector) + push!(patch_group[subdmps.projector], subdmps) else - patch_group[b] = [subdmps] + patch_group[subdmps.projector] = [subdmps] end end - patch_group_array = collect(patch_group) - - function sum_patches(group; patchorder, alg_sum, cutoff, maxdim, kwargs...) - b, subdmps_list = group - if length(subdmps_list) == 1 - return [subdmps_list[1]] + # Helper sum function + function sum_task(group; patchorder, alg_sum, cutoff, maxdim, kwargs...) + if length(group) == 1 + return [group[1]] else res = if length(patchorder) > 0 - _add_patching(subdmps_list; cutoff, maxdim, patchorder, kwargs...) + _add_patching(group; cutoff, maxdim, patchorder, kwargs...) else - [_add(subdmps_list...; alg=alg_sum, cutoff, maxdim, kwargs...)] + [_add(group...; alg=alg_sum, cutoff, maxdim, kwargs...)] end return res end end - summed_patches = pmap( - group -> sum_patches( - group; - patchorder=patchorder, - alg_sum=alg_sum, - cutoff=cutoff, - maxdim=maxdim, - kwargs..., - ), - patch_group_array, - ) + if parallel == :serial + summed_patches = map( + group -> sum_task( + group; + patchorder=patchorder, + alg_sum=alg_sum, + cutoff=cutoff, + maxdim=maxdim, + kwargs..., + ), + collect(values(patch_group)), + ) + elseif parallel == :distributed + summed_patches = pmap( + group -> sum_task( + group; + patchorder=patchorder, + alg_sum=alg_sum, + cutoff=cutoff, + maxdim=maxdim, + kwargs..., + ), + collect(values(patch_group)), + ) + end - for res in summed_patches - if res !== nothing - append!(M, vcat(res)) + # Assembling the PartitionedMPS + for subdmpss in summed_patches + if subdmpss !== nothing + append!(M, vcat(subdmpss)) end end diff --git a/src/partitionedmps.jl b/src/partitionedmps.jl index d7c67e4..9d3e223 100644 --- a/src/partitionedmps.jl +++ b/src/partitionedmps.jl @@ -23,11 +23,13 @@ PartitionedMPS(data::SubDomainMPS) = PartitionedMPS([data]) PartitionedMPS() = PartitionedMPS(SubDomainMPS[]) -projectors(obj::PartitionedMPS) = keys(obj) +projectors(obj::PartitionedMPS) = collect(keys(obj)) function Base.append!(a::PartitionedMPS, b::PartitionedMPS) - if !isdisjoint(collect(union(projectors(a), projectors(b)))) - error("Projectors are overlapping") + if !isdisjoint(vcat(projectors(a), projectors(b))) + error( + "Projectors are overlapping or identical. Resum of patches could be necessary." + ) end for (k, v) in b.data a.data[k] = v diff --git a/src/projector.jl b/src/projector.jl index 14a1309..6bb7eb4 100644 --- a/src/projector.jl +++ b/src/projector.jl @@ -30,8 +30,9 @@ end """ Constructing a projector from a single pair of index and integer. """ -Projector(singleproj::Pair{Index{T},Int}) where {T} = - Projector(Dict{Index,Int}(singleproj.first => singleproj.second)) +function Projector(singleproj::Pair{Index{T},Int}) where {T} + return Projector(Dict{Index,Int}(singleproj.first => singleproj.second)) +end function Base.hash(p::Projector, h::UInt) tmp = hash( From 6f6177f13f1e5ca6933363318b7f65f3c5806d49 Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Thu, 17 Apr 2025 19:55:34 +0200 Subject: [PATCH 10/11] Minor bug fixes --- benchmark/distributecontract.jl | 17 ++++++++++------- test/automul_tests.jl | 26 +++++++++++++++++++++----- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/benchmark/distributecontract.jl b/benchmark/distributecontract.jl index 4f9229c..be62883 100644 --- a/benchmark/distributecontract.jl +++ b/benchmark/distributecontract.jl @@ -34,7 +34,7 @@ using ITensors, ITensorMPS random_mpo_file = normpath(joinpath(dirname(pathof(PartitionedMPSs)), "../test/_util.jl")) Random.seed!(1234) -R = 5 +R = 10 d = 2 L = 5 @@ -42,21 +42,24 @@ sites_x = [Index(d, "Qubit,x=$x") for x in 1:R] sites_y = [Index(d, "Qubit,y=$y") for y in 1:R] sites_s = [Index(d, "Qubit,s=$s") for s in 1:R] -sites_xs = collect(Iterators.flatten(zip(sites_x, sites_s))) -sites_sy = collect(Iterators.flatten(zip(sites_s, sites_y))) +sites_xs = collect(collect.(zip(sites_x, sites_s))) +sites_sy = collect(collect.(zip(sites_s, sites_y))) -Ψ_l = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_xs]; linkdims=L)) -Ψ_r = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_sy]; linkdims=L)) +sites_xs_flat = collect(Iterators.flatten(sites_xs)) +sites_sy_flat = collect(Iterators.flatten(sites_sy)) + +Ψ_l = ITensorMPS.convert(MPS, _random_mpo(sites_xs; linkdims=L)) +Ψ_r = ITensorMPS.convert(MPS, _random_mpo(sites_sy; linkdims=L)) proj_lev_l = 4 proj_l = vec([ - Dict(zip(sites_xs, combo)) for + Dict(zip(sites_xs_flat, combo)) for combo in Iterators.product((1:d for _ in 1:proj_lev_l)...) ]) proj_lev_r = 6 proj_r = vec([ - Dict(zip(sites_sy, combo)) for + Dict(zip(sites_sy_flat, combo)) for combo in Iterators.product((1:d for _ in 1:proj_lev_r)...) ]) diff --git a/test/automul_tests.jl b/test/automul_tests.jl index f002450..913874c 100644 --- a/test/automul_tests.jl +++ b/test/automul_tests.jl @@ -13,7 +13,8 @@ import PartitionedMPSs: project, elemmul, automul, - default_cutoff + default_cutoff, + rearrange_siteinds import FastMPOContractions as FMPOC @@ -60,7 +61,7 @@ import FastMPOContractions as FMPOC test_points = [[rand(1:d) for __ in 1:N] for _ in 1:1000] - isapprox( + @test isapprox( [_evaluate(mps_element_prod, sites, p) for p in test_points], [_evaluate(Ψ, sites, p)^2 for p in test_points]; atol=sqrt(default_cutoff()), # default_cutoff() = 1e-25 is the contraction cutoff @@ -77,6 +78,7 @@ import FastMPOContractions as FMPOC sites_l = [Index(d, "Qubit, l=$l") for l in 1:N] sites_mn = collect(Iterators.flatten(collect.(zip(sites_m, sites_n)))) sites_nl = collect(Iterators.flatten(collect.(zip(sites_n, sites_l)))) + final_sites = collect(Iterators.flatten(collect.(zip(sites_m, sites_l)))) Ψ_l = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_mn]; linkdims=L)) Ψ_r = ITensorMPS.convert(MPS, _random_mpo([[x] for x in sites_nl]; linkdims=L)) @@ -102,11 +104,25 @@ import FastMPOContractions as FMPOC ) mps_matmul = MPS(matmul) - naive_matmul = FMPOC.contract_mpo_mpo( - MPO(collect(Ψ_l)), MPO(collect(Ψ_r)); alg="naive" + sites_mn_vec = collect(collect.(zip(sites_m, sites_n))) + sites_nl_vec = collect(collect.(zip(sites_n, sites_l))) + + mpo_l = MPO(collect(rearrange_siteinds(Ψ_l, sites_mn_vec))) + mpo_r = MPO(collect(rearrange_siteinds(Ψ_r, sites_nl_vec))) + + naive_matmul = FMPOC.contract_mpo_mpo(mpo_l, mpo_r; alg="naive") + mps_naive_matmul = rearrange_siteinds( + ITensorMPS.convert(MPS, naive_matmul), [[x] for x in final_sites] ) - mps_naive_matmul = ITensorMPS.convert(MPS, naive_matmul) @test mps_matmul ≈ mps_naive_matmul + + test_points = [[rand(1:d) for __ in 1:(2 * N)] for _ in 1:1000] + + @test isapprox( + [_evaluate(mps_matmul, final_sites, p) for p in test_points], + [_evaluate(mps_naive_matmul, final_sites, p) for p in test_points]; + atol=sqrt(default_cutoff()), # default_cutoff() = 1e-25 is the contraction cutoff + ) end end From 4db078bad387f1020c1c47fb8559cc5d3643084d Mon Sep 17 00:00:00 2001 From: Gianluca Grosso Date: Sun, 20 Apr 2025 22:14:28 +0200 Subject: [PATCH 11/11] Implement adaptive contraction and bug fixes --- Project.toml | 5 +- coverage/lcov.info | 836 ++++++++++++++++++++------------------ src/adaptivemul.jl | 189 +++++++-- src/contract.jl | 2 +- src/subdomainmps.jl | 5 +- test/adaptivemul_tests.jl | 136 +++++++ test/runtests.jl | 1 + 7 files changed, 755 insertions(+), 419 deletions(-) create mode 100644 test/adaptivemul_tests.jl diff --git a/Project.toml b/Project.toml index 59dea03..f354f5c 100644 --- a/Project.toml +++ b/Project.toml @@ -18,13 +18,14 @@ TCIAlgorithms = {url = "https://github.com/tensor4all/TCIAlgorithms.jl.git"} [compat] Distributed = "1" EllipsisNotation = "1" -FastMPOContractions = "0.2.5" +FastMPOContractions = "0.2.8" ITensorMPS = "0.3.2" ITensors = "0.7" OrderedCollections = "1.6.3" julia = "1.6" [extras] +QuanticsGrids = "634c7f73-3e90-4749-a1bd-001b8efc642d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" TCIAlgorithms = "baf62351-2e82-41dd-9129-4f5768a618e1" TCIITensorConversion = "9f0aa9f4-9415-4e6a-8795-331ebf40aa04" @@ -32,4 +33,4 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random", "TCIAlgorithms", "TensorCrossInterpolation", "TCIITensorConversion"] +test = ["Test", "Random", "TCIAlgorithms", "TensorCrossInterpolation", "TCIITensorConversion", "QuanticsGrids"] diff --git a/coverage/lcov.info b/coverage/lcov.info index 2eaeeca..e51b28d 100644 --- a/coverage/lcov.info +++ b/coverage/lcov.info @@ -1,46 +1,107 @@ SF:src/PartitionedMPSs.jl DA:16,0 -DA:17,47 +DA:17,52 LH:1 LF:2 end_of_record SF:src/adaptivemul.jl -DA:8,0 -DA:9,0 -DA:10,0 -DA:11,0 -DA:12,0 -DA:14,0 -DA:15,0 -DA:17,0 -DA:21,0 -DA:23,0 +DA:8,896 +DA:9,896 +DA:10,896 +DA:11,14848 +DA:12,14848 +DA:14,14848 +DA:15,14848 +DA:17,896 +DA:21,336 +DA:23,336 DA:24,0 -DA:26,0 +DA:26,336 DA:29,0 -DA:36,0 -DA:37,0 -DA:38,0 -DA:39,0 -DA:48,0 -DA:57,0 +DA:36,1680 +DA:37,560 +DA:38,560 +DA:39,1680 +DA:40,0 +DA:42,560 +DA:46,2 +DA:49,2 +DA:52,50 +DA:53,1968 +DA:54,1632 +DA:55,336 +DA:56,336 +DA:57,336 DA:59,0 -DA:60,0 -DA:61,0 -DA:62,0 -DA:64,0 -DA:65,0 +DA:62,1678 +DA:65,26 +DA:66,320 DA:67,0 -DA:69,0 -DA:73,0 -DA:75,0 -DA:76,0 -DA:77,0 -DA:79,0 +DA:69,344 +DA:72,2 +DA:74,4 +DA:75,24 +DA:76,24 +DA:78,672 +DA:79,336 DA:80,0 -DA:82,0 -LH:0 -LF:34 +DA:82,336 +DA:84,336 +DA:85,24 +DA:86,46 +DA:88,2 +DA:94,4 +DA:103,912 +DA:106,7 +DA:107,5 +DA:108,136 +DA:109,136 +DA:110,696 +DA:111,560 +DA:112,560 +DA:114,560 +DA:117,1120 +DA:122,560 +DA:123,560 +DA:125,448 +DA:128,224 +DA:129,112 +DA:130,0 +DA:134,0 +DA:135,0 +DA:136,0 +DA:138,112 +DA:139,112 +DA:140,112 +DA:142,224 +DA:144,448 +DA:146,448 +DA:148,112 +DA:154,336 +DA:156,112 +DA:160,448 +DA:161,560 +DA:164,136 +DA:165,56 +DA:167,136 +DA:168,5 +DA:170,2 +DA:173,2 +DA:182,4 +DA:192,2 +DA:195,2 +DA:198,2 +DA:201,2 +DA:202,4 +DA:204,160 +DA:206,80 +DA:207,0 +DA:209,80 +DA:217,80 +DA:219,158 +DA:221,2 +LH:84 +LF:95 end_of_record SF:src/automul.jl DA:16,2 @@ -78,84 +139,84 @@ DA:101,10 DA:102,10 DA:103,0 DA:105,10 -DA:106,7 -DA:107,7 -DA:109,3 -DA:110,3 +DA:106,4 +DA:107,4 +DA:109,6 +DA:110,6 DA:113,10 DA:114,1 LH:40 LF:41 end_of_record SF:src/bak/conversion.jl -DA:4,3 -DA:5,3 -DA:7,3 -DA:8,3 -DA:9,3 -DA:11,3 -DA:20,3 -DA:21,6 -DA:30,9 -DA:32,3 -DA:41,3 -DA:42,3 -DA:43,12 -DA:44,24 -DA:45,6 -DA:47,36 -DA:48,21 -DA:50,3 -DA:56,1 -DA:57,1 +DA:4,83 +DA:5,83 +DA:7,83 +DA:8,83 +DA:9,83 +DA:11,83 +DA:20,83 +DA:21,1446 +DA:30,2809 +DA:32,83 +DA:41,83 +DA:42,83 +DA:43,1612 +DA:44,3224 +DA:45,502 +DA:47,4836 +DA:48,3141 +DA:50,83 +DA:56,3 +DA:57,3 LH:20 LF:20 end_of_record SF:src/contract.jl -DA:9,630 -DA:13,315 -DA:14,16 -DA:16,598 -DA:18,299 -DA:19,299 -DA:22,299 -DA:26,1700 -DA:27,1700 -DA:28,1700 -DA:30,1700 -DA:32,1700 -DA:33,824 -DA:36,876 -DA:37,14001 -DA:39,876 -DA:40,17237 -DA:41,1669 -DA:43,17237 -DA:44,2459 -DA:46,17237 -DA:48,876 -DA:52,313 -DA:53,626 -DA:54,10762 +DA:9,1782 +DA:13,891 +DA:14,0 +DA:16,1782 +DA:18,891 +DA:19,891 +DA:22,891 +DA:26,4868 +DA:27,4868 +DA:28,4868 +DA:30,4868 +DA:32,4868 +DA:33,2120 +DA:36,2748 +DA:37,43953 +DA:39,2748 +DA:40,78037 +DA:41,6469 +DA:43,78037 +DA:44,7019 +DA:46,78037 +DA:48,2748 +DA:52,329 +DA:53,658 +DA:54,10954 DA:55,0 -DA:57,2838 -DA:58,313 -DA:64,818 -DA:74,409 -DA:75,409 -DA:76,754 +DA:57,2934 +DA:58,329 +DA:64,850 +DA:74,425 +DA:75,425 +DA:76,786 DA:77,96 -DA:80,610 -DA:82,313 +DA:80,658 +DA:82,329 DA:83,0 -DA:86,313 -DA:87,313 +DA:86,329 +DA:87,329 DA:94,16 DA:105,8 DA:107,8 DA:108,128 DA:109,128 -DA:110,16 +DA:110,32 DA:112,128 DA:114,8 DA:115,0 @@ -165,106 +226,106 @@ DA:122,8 DA:123,0 DA:125,16 DA:128,8 -DA:133,512 -DA:137,512 -DA:138,9849 -DA:139,232 -DA:141,232 -DA:143,232 -DA:145,232 -DA:147,232 -DA:148,464 +DA:133,1184 +DA:137,1184 +DA:138,17955 +DA:139,552 +DA:141,552 +DA:143,552 +DA:145,552 +DA:147,552 +DA:148,1104 DA:150,0 -DA:152,232 -DA:154,9617 -DA:156,280 -DA:157,280 -DA:159,280 -DA:163,8 -DA:169,4 -DA:172,32 -DA:173,1368 -DA:174,1088 -DA:175,280 -DA:176,280 -DA:177,280 +DA:152,552 +DA:154,17403 +DA:156,632 +DA:157,632 +DA:159,632 +DA:163,10 +DA:169,5 +DA:172,41 +DA:173,1416 +DA:174,1120 +DA:175,296 +DA:176,296 +DA:177,296 DA:179,0 -DA:182,1112 -DA:185,52 -DA:186,1120 +DA:182,1151 +DA:185,61 +DA:186,1184 DA:187,0 -DA:189,1168 -DA:192,4 -DA:193,8 -DA:194,48 +DA:189,1240 +DA:192,5 +DA:193,10 +DA:194,56 DA:195,0 -DA:197,48 -DA:198,280 -DA:199,280 -DA:200,92 -DA:202,4 -DA:210,8 -DA:220,4 -DA:221,4 -DA:229,8 -DA:243,4 -DA:246,564 -DA:247,280 -DA:248,280 -DA:252,4 -DA:253,4 -DA:254,280 +DA:197,56 +DA:198,296 +DA:199,296 +DA:200,107 +DA:202,5 +DA:210,10 +DA:220,5 +DA:221,5 +DA:229,10 +DA:243,5 +DA:246,597 +DA:247,296 +DA:248,296 +DA:252,5 +DA:253,5 +DA:254,296 DA:256,0 DA:257,0 DA:261,0 -DA:265,284 -DA:270,4 -DA:271,4 -DA:272,280 -DA:273,232 -DA:275,48 -DA:277,280 -DA:280,100 -DA:281,48 +DA:265,301 +DA:270,5 +DA:271,5 +DA:272,296 +DA:273,240 +DA:275,56 +DA:277,296 +DA:280,117 +DA:281,56 DA:282,8 -DA:284,40 +DA:284,48 DA:285,4 -DA:287,76 -DA:289,40 -DA:293,4 -DA:294,8 -DA:295,48 +DA:287,92 +DA:289,48 +DA:293,5 +DA:294,10 +DA:295,56 DA:305,0 DA:306,0 -DA:320,4 -DA:321,48 -DA:322,96 -DA:324,48 -DA:326,4 -LH:115 +DA:320,5 +DA:321,56 +DA:322,112 +DA:324,56 +DA:326,5 +LH:114 LF:129 end_of_record SF:src/partitionedmps.jl -DA:8,172 -DA:9,172 -DA:10,172 -DA:11,1713 -DA:12,3355 -DA:13,174 -DA:15,1881 -DA:18,170 +DA:8,187 +DA:9,187 +DA:10,187 +DA:11,1891 +DA:12,3705 +DA:13,189 +DA:15,2074 +DA:18,185 DA:22,20 -DA:24,24 -DA:26,245 -DA:28,67 -DA:29,134 +DA:24,25 +DA:26,276 +DA:28,75 +DA:29,150 DA:30,0 -DA:34,134 -DA:35,1141 -DA:36,2215 -DA:37,67 -DA:40,67 -DA:41,67 +DA:34,150 +DA:35,1149 +DA:36,2223 +DA:37,75 +DA:40,75 +DA:41,75 DA:48,19 DA:49,19 DA:52,19 @@ -277,10 +338,10 @@ DA:70,1 DA:71,1 DA:74,1 DA:75,2 -DA:81,494 -DA:82,494 -DA:88,69 -DA:89,69 +DA:81,510 +DA:82,510 +DA:88,128 +DA:89,128 DA:96,3 DA:97,3 DA:100,4 @@ -333,9 +394,9 @@ DA:226,5 DA:229,14 DA:230,14 DA:232,7 -DA:236,62 -DA:239,31 -DA:240,1215 +DA:236,70 +DA:239,35 +DA:240,1375 DA:244,0 DA:247,0 DA:253,8 @@ -352,24 +413,24 @@ LH:89 LF:103 end_of_record SF:src/patching.jl -DA:7,4088 -DA:14,2044 +DA:7,4296 +DA:14,2148 DA:15,0 -DA:20,2044 -DA:23,2044 -DA:27,2040 -DA:29,1020 -DA:31,1020 -DA:32,1020 -DA:33,2040 -DA:34,2040 -DA:42,3060 -DA:44,1020 -DA:50,1021 -DA:51,19447 -DA:52,1021 +DA:20,2148 +DA:23,2148 +DA:27,2064 +DA:29,1032 +DA:31,1032 +DA:32,1032 +DA:33,2064 +DA:34,2064 +DA:42,3096 +DA:44,1032 +DA:50,1145 +DA:51,20931 +DA:52,1145 DA:53,0 -DA:55,1021 +DA:55,1145 DA:62,0 DA:69,0 DA:72,0 @@ -394,32 +455,32 @@ LH:28 LF:38 end_of_record SF:src/projector.jl -DA:8,26110 -DA:9,43857 -DA:10,122239 +DA:8,38466 +DA:9,67050 +DA:10,180176 DA:11,2 -DA:13,226729 -DA:14,26108 -DA:18,204 -DA:19,204 +DA:13,331766 +DA:14,38464 +DA:18,216 +DA:19,216 DA:22,1 DA:23,1 -DA:26,5484 -DA:27,5484 -DA:33,2058 -DA:34,2058 -DA:37,13044 -DA:38,13044 -DA:41,13044 +DA:26,6520 +DA:27,6520 +DA:33,2306 +DA:34,2306 +DA:37,18359 +DA:38,18359 +DA:41,18359 DA:44,0 -DA:46,10176154 +DA:46,11256883 DA:47,0 -DA:48,315 +DA:48,331 DA:50,4 DA:51,4 DA:54,4 DA:55,8 -DA:58,138932 +DA:58,147814 DA:63,3 DA:64,6 DA:65,3 @@ -427,32 +488,32 @@ DA:66,2 DA:68,2 DA:70,1 DA:73,1 -DA:78,2282491 -DA:79,4564982 -DA:80,4989705 -DA:81,2268782 -DA:83,5441846 -DA:84,13709 -DA:92,3733 -DA:93,3733 -DA:94,7466 -DA:95,21522 -DA:96,21478 -DA:98,43044 -DA:99,3733 +DA:78,2311825 +DA:79,4623650 +DA:80,5499223 +DA:81,2290912 +DA:83,6416622 +DA:84,20913 +DA:92,5637 +DA:93,5637 +DA:94,11274 +DA:95,30730 +DA:96,30667 +DA:98,61460 +DA:99,5637 DA:102,0 DA:103,0 -DA:105,2270063 -DA:106,2270063 -DA:109,1031468 -DA:110,1031468 -DA:116,241 -DA:117,241 -DA:118,2262195 -DA:119,2257126 +DA:105,2309196 +DA:106,2309196 +DA:109,1652364 +DA:110,1652364 +DA:116,268 +DA:117,268 +DA:118,2275239 +DA:119,2269838 DA:120,3 -DA:123,2267044 -DA:124,238 +DA:123,2280394 +DA:124,265 DA:130,0 DA:133,0 DA:134,0 @@ -481,161 +542,160 @@ LH:69 LF:83 end_of_record SF:src/subdomainmps.jl -DA:8,15878 -DA:9,15878 -DA:12,15878 -DA:13,15878 -DA:17,1884 -DA:19,19279 -DA:20,3401 +DA:8,20830 +DA:9,20830 +DA:12,20830 +DA:13,20830 +DA:17,3868 +DA:19,30567 +DA:20,9737 DA:22,0 -DA:23,2047 -DA:25,15878 -DA:26,15878 -DA:27,15878 -DA:28,24357 -DA:29,65671 -DA:30,1483 -DA:32,122863 -DA:33,15878 -DA:36,5481 -DA:37,5481 +DA:23,2711 +DA:25,20830 +DA:26,20830 +DA:27,20830 +DA:28,32903 +DA:29,82275 +DA:30,5475 +DA:32,152477 +DA:33,20830 +DA:36,6517 +DA:37,6517 DA:41,868 -DA:43,269285 -DA:44,1256993 -DA:48,269285 -DA:49,269285 -DA:50,1256993 +DA:43,398333 +DA:44,1884657 +DA:48,398333 +DA:49,398333 +DA:50,1884657 DA:51,0 -DA:53,269285 -DA:55,269285 -DA:58,10383 -DA:59,10383 -DA:60,10383 -DA:61,128 -DA:64,10255 -DA:69,0 -DA:72,0 -DA:75,0 -DA:78,0 -DA:81,3411 -DA:82,3411 -DA:85,11 -DA:88,11 -DA:91,142 -DA:94,142 -DA:97,163876 -DA:99,163876 -DA:102,15878 -DA:103,31756 +DA:53,398333 +DA:55,398333 +DA:58,14219 +DA:59,14219 +DA:60,128 +DA:63,14091 +DA:68,0 +DA:71,0 +DA:74,0 +DA:77,0 +DA:80,3683 +DA:81,3683 +DA:84,11 +DA:87,11 +DA:90,154 +DA:93,154 +DA:96,235844 +DA:98,235844 +DA:101,20830 +DA:102,41660 +DA:105,114 DA:106,114 DA:107,114 -DA:108,114 +DA:111,0 DA:112,0 -DA:113,0 -DA:116,10 -DA:117,4 +DA:115,10 +DA:116,4 +DA:122,0 DA:123,0 DA:124,0 -DA:125,0 -DA:127,0 -DA:130,0 -DA:136,6 -DA:137,3 -DA:140,34474 -DA:141,34474 -DA:144,6456 -DA:150,2088 -DA:151,4176 -DA:153,2088 -DA:154,2088 -DA:155,2088 -DA:156,2088 -DA:159,6614 -DA:160,3307 -DA:161,1219 -DA:162,2088 +DA:126,0 +DA:129,0 +DA:135,6 +DA:136,3 +DA:139,156074 +DA:140,156074 +DA:143,8104 +DA:149,2200 +DA:150,4400 +DA:152,2200 +DA:153,2200 +DA:154,2200 +DA:155,2200 +DA:158,7158 +DA:159,3579 +DA:160,1379 +DA:161,2200 +DA:162,0 DA:163,0 -DA:164,0 -DA:168,0 -DA:169,2088 -DA:170,4368 -DA:171,2280 -DA:175,2088 -DA:176,4368 -DA:177,2088 -DA:179,0 -DA:183,2438 -DA:186,1219 -DA:189,6614 -DA:192,3307 +DA:167,0 +DA:168,2200 +DA:169,5904 +DA:170,3704 +DA:174,2200 +DA:175,5904 +DA:176,2200 +DA:178,0 +DA:182,2758 +DA:185,1379 +DA:188,7158 +DA:191,3579 +DA:197,16 DA:198,16 -DA:199,16 +DA:201,77 DA:202,77 -DA:203,77 +DA:205,0 DA:206,0 -DA:207,0 -DA:210,2974 -DA:211,1487 +DA:209,3294 +DA:210,1647 +DA:213,93 DA:214,93 -DA:215,93 -DA:216,0 +DA:215,0 +DA:217,93 DA:218,93 -DA:219,93 +DA:221,93 DA:222,93 -DA:223,93 -DA:226,188 +DA:225,188 +DA:228,94 DA:229,94 -DA:230,94 +DA:230,70 DA:231,70 DA:232,70 -DA:233,70 +DA:234,94 DA:235,94 -DA:236,94 -DA:237,70 +DA:236,70 +DA:237,32 DA:238,32 -DA:239,32 +DA:239,8 DA:240,8 -DA:241,8 -DA:244,70 -DA:246,94 +DA:243,70 +DA:245,94 +DA:248,0 DA:249,0 -DA:250,0 -DA:253,188 -DA:256,94 +DA:252,188 +DA:255,94 +DA:258,1 DA:259,1 DA:260,1 -DA:261,1 -DA:263,1 +DA:262,1 +DA:266,1 DA:267,1 -DA:268,1 -DA:269,3 -DA:270,0 -DA:272,3 -DA:274,1 -DA:277,42 +DA:268,3 +DA:269,0 +DA:271,3 +DA:273,1 +DA:276,42 +DA:279,42 DA:280,42 -DA:281,42 -DA:282,371 +DA:281,371 +DA:282,48 DA:283,48 DA:284,48 -DA:285,48 -DA:287,48 -DA:289,48 -DA:290,700 -DA:292,42 -DA:294,83 +DA:286,48 +DA:288,48 +DA:289,700 +DA:291,42 +DA:293,83 +DA:294,201 DA:295,201 -DA:296,201 -DA:297,361 -DA:298,42 +DA:296,361 +DA:297,42 +DA:300,1 DA:301,1 DA:302,1 -DA:303,1 +DA:305,0 DA:306,0 -DA:307,0 -LH:128 -LF:153 +LH:127 +LF:152 end_of_record SF:src/util.jl DA:2,0 @@ -672,32 +732,32 @@ DA:66,73 DA:67,146 DA:68,219 DA:69,73 -DA:74,114 -DA:75,114 -DA:77,114 -DA:80,114 -DA:81,228 -DA:82,114 -DA:83,114 -DA:84,1452 -DA:85,2258 -DA:86,332 -DA:88,3852 -DA:89,1926 -DA:90,1926 -DA:91,1926 -DA:92,1926 -DA:93,1926 -DA:94,1926 -DA:95,1926 -DA:96,2258 -DA:98,1452 -DA:99,1338 -DA:101,1566 -DA:103,1452 -DA:104,2790 -DA:105,114 -DA:106,114 +DA:74,118 +DA:75,118 +DA:77,118 +DA:80,118 +DA:81,236 +DA:82,118 +DA:83,118 +DA:84,1532 +DA:85,2358 +DA:86,362 +DA:88,3992 +DA:89,1996 +DA:90,1996 +DA:91,1996 +DA:92,1996 +DA:93,1996 +DA:94,1996 +DA:95,1996 +DA:96,2358 +DA:98,1532 +DA:99,1414 +DA:101,1650 +DA:103,1532 +DA:104,2946 +DA:105,118 +DA:106,118 DA:110,11 DA:121,10 DA:124,5 diff --git a/src/adaptivemul.jl b/src/adaptivemul.jl index f77c596..717908d 100644 --- a/src/adaptivemul.jl +++ b/src/adaptivemul.jl @@ -34,49 +34,188 @@ Project the LazyContraction object to `prj` before evaluating it. This may result in projecting the external indices of `a` and `b`. """ function project(obj::LazyContraction, prj::Projector; kwargs...)::LazyContraction - new_a = project(obj.a, a.projector & prj; kwargs...) - new_b = project(obj.b, b.projector & prj; kwargs...) + new_a = project(obj.a, prj; kwargs...) + new_b = project(obj.b, prj; kwargs...) + if isnothing(new_a) || isnothing(new_b) + error("New projector is not compatible with SubDomainMPSs projectors.") + end return LazyContraction(new_a, new_b) end +# Preprocessing of the patches to obtain all the contraction tasks from two PartitionedMPSs +function _adaptivecontraction_tasks( + M1::PartitionedMPS, M2::PartitionedMPS +)::Dict{Projector,Vector{Union{SubDomainMPS,LazyContraction}}} + final_patches = Dict{Projector,Vector{Tuple{SubDomainMPS,SubDomainMPS}}}() + # Add a new patch only if the two subdmps are compatible (overlapping internal projected + # sites) and the new patch is non-overlapping with all the existing ones. + for m1 in values(M1), m2 in values(M2) + tmp_prj = _projector_after_contract(m1, m2)[1] + if tmp_prj !== nothing + patch = add_result_patch!(final_patches, tmp_prj) + if haskey(final_patches, patch) + push!(final_patches[patch], (m1, m2)) + else + final_patches[patch] = (m1, m2) + end + end + end + + # Sanity check + for p1 in keys(final_patches), p2 in keys(final_patches) + if p1 != p2 && hasoverlap(p1, p2) + error("After contraction, projectors must not overlap.") + end + end + + # Transform the SubDomainMPS pairs in LazyContraction wrappers + tasks = Dict{Projector,Vector{Union{SubDomainMPS,LazyContraction}}}() + + for (proj, subdmps_pair) in final_patches + resultvec = Union{SubDomainMPS,LazyContraction}[] + for pair in subdmps_pair + # Trim projectors to produce only non-overlapping patches + lc = project(lazycontraction(pair...), proj) + if lc === nothing + @warn "LazyContraction == nothing. Faulty patch preprocessing..." proj pair + else + push!(resultvec, lc) + end + end + tasks[proj] = resultvec + end + + return tasks +end + +# Performs the patched contraction of two PartitionedMPS +# For each compatible pair of patches the contraction is attempted and split in smaller patches +# if the result exceeds the fixed bond dimension. +function patch_contract!( + patches::Dict{Projector,Vector{Union{SubDomainMPS,LazyContraction}}}, + pordering::AbstractVector{Index{IndsT}}, + maxdim, + cutoff; + alg="fit", + kwargs..., +) where {IndsT} + # A small helper + has_lazy() = any(any(lc -> lc isa LazyContraction, v) for v in values(patches)) + + # Keep iterating until no LazyContraction remains + while has_lazy() + for prj in collect(keys(patches)) + blockvec = patches[prj] + i = 1 + while i <= length(blockvec) + m = blockvec[i] + if m isa LazyContraction + # Attempt the actual contraction + contracted = contract( + m.a, m.b, ; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs... + ) + isnothing(contracted) && error( + "Some contractions failed. Double check the patch ordering..." + ) + + # Check the bond dimension of result + max_bdim = maxbonddim(contracted) + if max_bdim < maxdim + # Good: replace the lazy contraction with final SubDomainMPS + blockvec[i] = contracted + else + # Too large => we must expand the projector + nextprjidx = _next_projindex(m.projector, pordering) + if nextprjidx === nothing + @warn( + "Cannot expand further; bond dimension still exceeds maxdim." + ) + # Keep it anyway + blockvec[i] = contracted + i += 1 + continue + else + popat!(blockvec, i) + d = ITensors.dim(nextprjidx) + for val in 1:d + # Construct a new projector that includes (nextprjidx => val) + new_prj = m.projector & Projector(nextprjidx => val) + + new_m = project(m, new_prj) + # Add a new lazy contraction to patches[new_prj] + push!( + get!( + () -> Vector{Union{SubDomainMPS,LazyContraction}}(), + patches, + new_prj, + ), + new_m, + ) + end + # don't increment i, because we removed the old lazy contraction + continue + end + end + end + i += 1 + end + + # If the current patch ended up with an empty vector, remove it + if isempty(blockvec) + delete!(patches, prj) + end + end + end + + @assert !has_lazy() "Some LazyContraction are still present. Something went wrong..." + + # Check that the final projectors are not overlapping + return isdisjoint(collect(keys(patches))) || error("Overlapping projectors") +end + """ -Perform contruction of two PartitionedMPS objects. +Perform contraction of two PartitionedMPS objects. + +The resulting patches after the contraction are patch-added if projected on the same final patch. -The SubDomainMPS objects of each PartitionedMPS do not overlap with each other. -This makes the algorithm much simpler """ function adaptivecontract( a::PartitionedMPS, b::PartitionedMPS, - pordering::AbstractVector{Index}=Index[]; + pordering::AbstractVector{Index{IndsT}}=Index{IndsT}[]; alg="fit", + alg_sum="fit", cutoff=default_cutoff(), maxdim=default_maxdim(), kwargs..., -) - patches = Dict{Projector,Vector{Union{SubDomainMPS,LazyContraction}}}() - - for x in values(a), y in values(b) # FIXME: Naive loop over O(N^2) pairs - xy = lazycontraction(x, y) - if xy === nothing - continue - end - if haskey(patches, xy.projector) - push!(patches[xy.projector], xy) - else - patches[xy.projector] = [xy] - end - end +) where {IndsT} + patches = _adaptivecontraction_tasks(a, b) # Check no overlapping projectors. - # This should be prohibited by the fact that the blocks in each SubDomainMPS obejct do not overlap. isdisjoint(collect(keys(patches))) || error("Overlapping projectors") + # Perform the iterative patch contraction + patch_contract!(patches, pordering, maxdim, cutoff; alg=alg, kwargs...) + + # Resum SubDomainMPS on the same patch result_blocks = SubDomainMPS[] - for (p, muls) in patches - subdmps = [contract(m.a, m.b; alg, cutoff, maxdim, kwargs...) for m in muls] - #patches[p] = +(subdmps...; alg="fit", cutoff, maxdim) - push!(result_blocks, +(subdmps...; alg="fit", cutoff, maxdim)) + for (prj, blockvec) in patches + # Each entry in blockvec is now guaranteed to be a SubDomainMPS + subdmps_list = Vector{SubDomainMPS}(blockvec) + + if length(subdmps_list) == 1 + push!(result_blocks, subdmps_list[1]) + else + patch_sum = _add_patching( + subdmps_list; + alg=alg_sum, + cutoff=cutoff, + maxdim=maxdim, + patchorder=pordering, + ) + + append!(result_blocks, patch_sum) + end end return PartitionedMPS(result_blocks) diff --git a/src/contract.jl b/src/contract.jl index e47f160..2aeefcf 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -195,7 +195,7 @@ function _contraction_tasks( continue end for (subdmps1, subdmps2) in submps_pairs - push!(tasks, (proj, subdmps1, subdmps2)) + push!(tasks, (proj, project(subdmps1, proj), project(subdmps2, proj))) end end diff --git a/src/subdomainmps.jl b/src/subdomainmps.jl index 7162545..6265377 100644 --- a/src/subdomainmps.jl +++ b/src/subdomainmps.jl @@ -56,13 +56,12 @@ function project(tensor::ITensor, projector::Projector) end function project(projΨ::SubDomainMPS, projector::Projector)::Union{Nothing,SubDomainMPS} - newprj = projector & projΨ.projector - if newprj === nothing + if !hasoverlap(projector, projΨ.projector) return nothing end return SubDomainMPS( - MPS([project(projΨ.data[n], newprj) for n in 1:length(projΨ.data)]), newprj + MPS([project(projΨ.data[n], projector) for n in 1:length(projΨ.data)]), projector ) end diff --git a/test/adaptivemul_tests.jl b/test/adaptivemul_tests.jl new file mode 100644 index 0000000..9f37130 --- /dev/null +++ b/test/adaptivemul_tests.jl @@ -0,0 +1,136 @@ +import PartitionedMPSs: + PartitionedMPSs, + PartitionedMPS, + SubDomainMPS, + project, + adaptivecontract, + contract, + Projector +import FastMPOContractions as FMPOC +import QuanticsGrids as QG +import TensorCrossInterpolation as TCI +import TCIAlgorithms as TCIA + +using ITensors, ITensorMPS + +Random.seed!(1234) + +@testset "adaptivemul.jl" begin + @testset "adaptivecontract" begin + R = 8 + L = 5 + d = 2 + + tol = 1e-5 + + sites_x = [Index(d, "Qubit,x=$n") for n in 1:R] + sites_y = [Index(d, "Qubit,y=$n") for n in 1:R] + sites_s = [Index(d, "Qubit,s=$n") for n in 1:R] + + sites_l = collect(collect.(zip(sites_x, sites_s))) + sites_r = collect(collect.(zip(sites_s, sites_y))) + pordering = final_sites = collect(Iterators.flatten(zip(sites_x, sites_y))) + + mpo_l = _random_mpo(sites_l; linkdims=L) + mpo_r = _random_mpo(sites_r; linkdims=L) + + proj_lev_l = 3 + proj_l = vec([ + Dict(zip(collect(Iterators.flatten(sites_l)), combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_l)...) + ]) + + proj_lev_r = 2 + proj_r = vec([ + Dict(zip(collect(Iterators.flatten(sites_r)), combo)) for + combo in Iterators.product((1:d for _ in 1:proj_lev_r)...) + ]) + + partΨ_l = PartitionedMPS(project.(Ref(MPS(collect(mpo_l))), proj_l)) + partΨ_r = PartitionedMPS(project.(Ref(MPS(collect(mpo_r))), proj_r)) + + part_normal = PartitionedMPSs.contract(partΨ_l, partΨ_r; cutoff=tol^2, alg="fit") + part_adaptive = adaptivecontract( + partΨ_l, partΨ_r, pordering; cutoff=tol^2, maxdim=23 + ) + + @test MPS(part_adaptive) ≈ MPS(part_normal) + + naive_mpo = FMPOC.contract_mpo_mpo(mpo_l, mpo_r; cutoff=tol^2, alg="naive") + + @test MPS(part_adaptive) ≈ MPS(collect(naive_mpo)) + end + + @testset "2D Gaussians" begin + + # Integrand function + gaussian(x, y) = exp(-1.0 * (x^2 + y^2)) + + # Analytic solution + analyticIntegral(x, y) = sqrt(π / 2) * exp(-1.0 * (x^2 + y^2)) + + # Function parameters + D = 2 + x_0 = 10.0 + + # Simulation parameters + R = 20 + unfoldingscheme = :fused + mb = 25 + tol = 1e-9 + + localdims = fill(2^D, R) + sitedims = fill([2, 2], R) + + grid = QG.DiscretizedGrid{D}( + R, Tuple(fill(-x_0, D)), Tuple(fill(x_0, D)); unfoldingscheme=unfoldingscheme + ) + q_gauss = x -> gaussian(QG.quantics_to_origcoord(grid, x)...) + patch_ordering = TCIA.PatchOrdering(collect(1:R)) + + gauss_patch = reshape( + TCIA.adaptiveinterpolate( + TCIA.makeprojectable(Float64, q_gauss, localdims), + patch_ordering; + verbosity=0, + maxbonddim=mb, + tolerance=tol, + ), + sitedims, + ) + + sites_x = [Index(2, "Qubit,x=$n") for n in 1:R] + sites_y = [Index(2, "Qubit,y=$n") for n in 1:R] + sites_s = [Index(2, "Qubit,s=$n") for n in 1:R] + + sites_l = collect(collect.(zip(sites_x, sites_s))) + sites_r = collect(collect.(zip(sites_s, sites_y))) + pordering = final_sites = collect(Iterators.flatten(zip(sites_x, sites_y))) + + part_mps_l = PartitionedMPS(gauss_patch, sites_l) + part_mps_r = PartitionedMPS(gauss_patch, sites_r) + + part_adaptive = adaptivecontract( + part_mps_l, part_mps_r, pordering; cutoff=tol^2, maxdim=mb + ) + + adaptive_mps = PartitionedMPSs.rearrange_siteinds( + MPS(part_adaptive), [[x] for x in final_sites] + ) + + N_err = 1000 + points = [(rand() * x_0 - x_0 / 2, rand() * x_0 - x_0 / 2) for _ in 1:N_err] + quantics_fused_points = QG.origcoord_to_quantics.(Ref(grid), points) + quantics_points = [ + QG.interleave_dimensions(QG.unfuse_dimensions(p, D)...) for + p in quantics_fused_points + ] + + adaptive_points = [ + (2x_0 / 2^R) * _evaluate(adaptive_mps, final_sites, p) for p in quantics_points + ] + analytic_points = [analyticIntegral(p...) for p in points] + + @test all(isapprox.(analytic_points, adaptive_points; atol=sqrt(tol))) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 11f900e..1df482d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,3 +14,4 @@ include("patching_tests.jl") include("util_tests.jl") include("automul_tests.jl") include("bak/conversion_tests.jl") +include("adaptivemul_tests.jl")