diff --git a/.gitignore b/.gitignore index b02ba6e..802a2c5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +./samples/ Manifest.toml -samples/ \ No newline at end of file +.DS_Store \ No newline at end of file diff --git a/Project.toml b/Project.toml index bccfef1..0579f89 100644 --- a/Project.toml +++ b/Project.toml @@ -4,19 +4,45 @@ authors = ["Ryo Watanabe "] version = "0.1.0" [deps] +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GraphTikZ = "cef0280d-a2bf-4776-a511-cf6253a7debc" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7" +ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" +NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" +QuanticsGrids = "634c7f73-3e90-4749-a1bd-001b8efc642d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SimpleTensorNetworks = "3075f829-f72e-4896-a859-7fe0a9cabb9b" +SparseIR = "4fe2279e-80f0-4adb-8463-ee114ff56b7d" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +T4ARegistrator = "52e3f0f8-52ab-4798-a033-7b9fb81451b4" TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" [compat] +Combinatorics = "1.0.3" DataGraphs = "0.2.5" +DataStructures = "0.18.22" +Distributions = "0.25.120" +Flux = "0.16.4" +GraphTikZ = "0.1.0" Graphs = "1.12.0" +ITensorNetworks = "0.14.1" +ITensors = "0.9.11" +NPZ = "0.4.3" NamedGraphs = "0.6.4" +QuanticsGrids = "0.6.0" Random = "1.10" +Revise = "3.9.0" SimpleTensorNetworks = "0.1.0" +SparseIR = "1.1.4" +Statistics = "1.11.1" +T4ARegistrator = "0.2.1" TensorCrossInterpolation = "0.9.13" [extras] diff --git a/samples/1d_tightbinding_model.jl b/samples/1d_tightbinding_model.jl new file mode 100644 index 0000000..ef3476a --- /dev/null +++ b/samples/1d_tightbinding_model.jl @@ -0,0 +1,65 @@ +using Random +using LinearAlgebra +using TreeTCI: crossinterpolate, crossinterpolate_with_structuralsearch, crossinterpolate_with_3site_swapping +using NamedGraphs: NamedGraph, add_edge!, edges, src, dst +using QuanticsGrids +const QG = QuanticsGrids +using SparseIR +using ITensorNetworks +using TreeTCI: ttnopt +const ITN = ITensorNetworks +using ITensors +using NPZ +include("utils.jl") +include("graphs.jl") + + +ε(k) = 2*cos(k) + cos(5*k) + 2*cos(20*k) + +gk(m::Int, kx::Float64; β::Float64=10.0) = begin + m = 2 * m + 1 + ν = FermionicFreq(m) + iν = SparseIR.valueim(ν, β) + return 1 / (iν - ε(kx)) +end + +function gkb(b::Vector{Int}, n_m::Int, n_kx::Int; mu::Float64=0.0, β::Float64=10.0, layout::Symbol=:block) + @assert length(b) == n_m + n_kx + parts = split_bits(b; group_bits=[n_m, n_kx], layout=layout) + mb, kxb = parts + mbit = length(mb) + kxbit = length(kxb) + Nm = 2^mbit + Nkx = 2^kxbit + im = frombins(mb) + ikx = frombins(kxb) + @assert im ≤ Nm + @assert ikx ≤ Nkx + kx = 2π * (ikx - 1)/Nkx + m = (im - 1) + return gk(m, kx; mu=mu, β=β) +end + +function main() + nkx_bit = 10 + nm_bit = 10 + localdims = fill(2, nkx_bit + nm_bit) + f(v) = gkb(v, nm_bit, nkx_bit; layout=:interleave) + g = graph_TT(nkx_bit + nm_bit) + maxbonddim = 200 + kwargs = (maxbonddim = maxbonddim, maxiter = 100, tolerance = 1e-13) + center_vertex = (nkx_bit + nm_bit) ÷ 2 + # center_vertex = 1 + + ttn, ranks, errors = crossinterpolate(ComplexF64, f, localdims, g; center_vertex = center_vertex, kwargs...) + @show last(ranks) + + + g_tmp, original_entanglements, entanglements = ttnopt(ttn; ortho_vertex = center_vertex, max_degree = 2) + ttn_, ranks, errors = crossinterpolate(ComplexF64, f, localdims, g_tmp; center_vertex = center_vertex, kwargs...) + @show last(ranks) + + return 0 +end + +main() \ No newline at end of file diff --git a/samples/2d_model.jl b/samples/2d_model.jl new file mode 100644 index 0000000..8983ed6 --- /dev/null +++ b/samples/2d_model.jl @@ -0,0 +1,92 @@ +using Random +using LinearAlgebra +using TreeTCI: crossinterpolate, crossinterpolate_with_structuralsearch, crossinterpolate_with_3site_swapping +using NamedGraphs: NamedGraph, add_edge!, edges, src, dst +using QuanticsGrids +const QG = QuanticsGrids +using SparseIR +using ITensorNetworks +using TreeTCI: ttnopt +const ITN = ITensorNetworks +using ITensors +using NPZ +include("utils.jl") +include("graphs.jl") + + +ε(kx, ky) = -2*cos(kx) - 2*cos(ky) + +gk(kx::Float64, ky::Float64; mu::Float64=0.0, β::Float64=10.0) = begin + m = 1 + ν = FermionicFreq(m) + iν = SparseIR.valueim(ν, β) + return 1 / (iν - ε(kx, ky) + mu) +end + +function gkb(b::Vector{Int}, n_kx::Int, n_ky; mu::Float64=0.0, β::Float64=10.0, layout::Symbol=:block) + @assert length(b) == n_kx + n_ky + parts = split_bits(b; group_bits=[n_kx, n_ky], layout=layout) + kxb, kyb = parts + kxbit = length(kxb) + kybit = length(kyb) + Nkx = 2^kxbit + Nky = 2^kybit + iky = frombins(kyb) + ikx = frombins(kxb) + @assert ikx ≤ Nkx + @assert iky ≤ Nky + kx = 2π * (ikx - 1)/Nkx + ky = 2π * (iky - 1)/Nky + return gk(kx, ky; mu=mu, β=β) +end + + +function calculate_entanglements(entanglements) + ee = 0.0 + edges = [] + edge_vals = [] + for (key, value) in entanglements + ee += value + push!(edges, key) + push!(edge_vals, value) + end + @show ee / length(edges) + @show edges + @show edge_vals +end + +function main() + nkx_bit = 10 + nky_bit = 10 + localdims = fill(2, nkx_bit + nky_bit) + f(v) = gkb(v, nky_bit, nkx_bit; layout=:interleave, mu=-1.0, β=10.0) + g = graph_TT(nkx_bit + nky_bit) + maxbonddim = 200 + kwargs = (maxbonddim = maxbonddim, maxiter = 100, tolerance = 1e-13) + center_vertex = (nkx_bit + nky_bit) ÷ 2 + + ttn, ranks, errors = crossinterpolate(ComplexF64, f, localdims, g; center_vertex = center_vertex, kwargs...) + + g_tmp, original_entanglements, entanglements = ttnopt(ttn; ortho_vertex = center_vertex, max_degree = 1) + ttn_, ranks, errors = crossinterpolate(ComplexF64, f, localdims, g_tmp; center_vertex = center_vertex, kwargs...) + println("original_entanglements") + calculate_entanglements(original_entanglements) + @show last(ranks) + + println("--------------------------------") + println("ΔG = 2") + calculate_entanglements(entanglements) + @show last(ranks) + + g_tmp, original_entanglements, entanglements = ttnopt(ttn; ortho_vertex = center_vertex, max_degree = 2) + ttn_, ranks, errors = crossinterpolate(ComplexF64, f, localdims, g_tmp; center_vertex = center_vertex, kwargs...) + + println("--------------------------------") + println("ΔG = 3") + calculate_entanglements(entanglements) + @show last(ranks) + + return 0 +end + +main() \ No newline at end of file diff --git a/samples/graphs.jl b/samples/graphs.jl new file mode 100644 index 0000000..72ec755 --- /dev/null +++ b/samples/graphs.jl @@ -0,0 +1,9 @@ +using NamedGraphs: NamedGraph, add_edge! + +function graph_TT(R::Int) + g = NamedGraph(R) + for i in 1:R-1 + add_edge!(g, i, i+1) + end + return g +end \ No newline at end of file diff --git a/samples/sample_strucuralsearch.jl b/samples/sample_strucuralsearch.jl new file mode 100644 index 0000000..2644844 --- /dev/null +++ b/samples/sample_strucuralsearch.jl @@ -0,0 +1,49 @@ +using Random +using LinearAlgebra +using TreeTCI: crossinterpolate_with_3site_swapping +using Statistics +using Distributions +using NamedGraphs: NamedGraph, add_edge!, edges +using QuanticsGrids +const QG = QuanticsGrids + + +function build_problem() + Random.seed!(1234) + R = 4 + μ = zeros(3) + Σ = rand(LKJ(3, 50.0)) + dist = MvNormal(μ, Hermitian(Σ)) + f(x,y,z) = pdf(dist, [x,y,z]) + + grid = QG.DiscretizedGrid{3}(R, (-5,-5,-5), (5,5,5); unfoldingscheme=:interleaved) + fq = QG.quanticsfunction(Float64, grid, f) + + nsites = 3R + g = NamedGraph(nsites) + for i in 1:(nsites-1) + add_edge!(g, i, i+1) + end + localdims = fill(grid.base, nsites) + + return fq, localdims, g +end + +function main() + + fq, localdims, g = build_problem() + + kwargs = ( + maxbonddim = 10, + tolerance = 1e-10, + maxiter = 100, + ) + + tci = crossinterpolate_with_3site_swapping(Float64, fq, localdims, g; kwargs...) + + ttn = TreeTensorNetwork(tci.g, tci.sitetensors) + + return tci +end + +main() \ No newline at end of file diff --git a/samples/sample_treetci.jl b/samples/sample_treetci.jl new file mode 100644 index 0000000..b0c5d3a --- /dev/null +++ b/samples/sample_treetci.jl @@ -0,0 +1,27 @@ +using Test +using TreeTCI +import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge + +function main() + # make graph + g = NamedGraph(10) + add_edge!(g, 1, 3) + add_edge!(g, 2, 3) + add_edge!(g, 3, 5) + + add_edge!(g, 4, 5) + add_edge!(g, 5, 7) + add_edge!(g, 6, 7) + add_edge!(g, 7, 8) + add_edge!(g, 8, 9) + add_edge!(g, 8, 10) + + localdims = fill(2, length(vertices(g))) + f(v) = 1 / (1 + v' * v) + kwargs = (maxbonddim = 20, maxiter = 10) + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...) + @show ttn([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), f([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + @show ttn([1, 2, 1, 2, 1, 2, 1, 2, 1, 2]), f([1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) +end + +main() diff --git a/samples/sample_treetci2.jl b/samples/sample_treetci2.jl new file mode 100644 index 0000000..4628ce5 --- /dev/null +++ b/samples/sample_treetci2.jl @@ -0,0 +1,32 @@ +using Test +using TreeTCI +import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge + +function f_pairwise(v::Vector{Int}) + s = 0.0 + for i in 1:2:length(v)-1 + s += v[i] * v[i+1] + end + return 1 / (1+s) +end + + +function main() + # make graph + g = NamedGraph(8) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 3, 4) + add_edge!(g, 4, 5) + add_edge!(g, 5, 6) + add_edge!(g, 6, 7) + add_edge!(g, 7, 8) + + localdims = fill(8, length(vertices(g))) + f(v) = f_pairwise(v) + kwargs = (maxbonddim = 64, maxiter = 10) + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...) + +end + +main() diff --git a/samples/utils.jl b/samples/utils.jl new file mode 100644 index 0000000..1d51a1b --- /dev/null +++ b/samples/utils.jl @@ -0,0 +1,111 @@ +""" + split_bits(b; group_bits, layout=:block) + +Split a bit sequence b of length `length(b)` (where each element is 1/2) into groups +with bit counts specified by `group_bits`. + +Arguments +- b::AbstractVector{<:Integer} : Bit sequence of 1/2 (MSB-first) +- group_bits::AbstractVector{<:Integer} : Number of bits in each group (e.g., [mbit, kxbit, …]) +- layout::Symbol = :block | :interleave + - :block → Assumes b = [ grp1..., grp2..., … ] concatenation + - :interleave → Assumes b = [ g1_1, g2_1, …, gG_1, g1_2, g2_2, … ] order + Assumes all groups have equal bit counts + +Returns +- Vector{Vector{Int}} : Bit sequences for each group (each element is 1/2) +""" +function split_bits(b::AbstractVector{<:Integer}; group_bits::AbstractVector{<:Integer}, layout::Symbol=:block) + @assert all(x -> x ≥ 0, group_bits) "group_bits must be nonnegative" + G = length(group_bits) + total = sum(group_bits) + @assert length(b) == total "length(b) must equal sum(group_bits)" + + if layout === :block + parts = Vector{Vector{Int}}(undef, G) + start = 1 + @inbounds for g in 1:G + len = group_bits[g] + parts[g] = collect(b[start:start+len-1]) + start += len + end + return parts + + elseif layout === :interleave + @assert length(unique(group_bits)) == 1 "interleave layout requires all group lengths equal" + L = group_bits[1] # 各グループの長さ + @assert L * G == length(b) + parts = [Vector{Int}(undef, L) for _ in 1:G] + # b = [ g1_1, g2_1, …, gG_1, g1_2, g2_2, …, gG_2, … ] + @inbounds for j in 1:L + base = (j-1)*G + for g in 1:G + parts[g][j] = b[base + g] + end + end + return parts + + else + error("layout must be :block or :interleave") + end +end + + +""" + join_bits(parts; layout=:block) + +Inverse of `split_bits`. Combines bit sequences `parts` from each group (1/2 1-based) +to return a single bit sequence b. + +Arguments +- parts::Vector{<:AbstractVector{<:Integer}} : Bit sequences for each group +- layout::Symbol = :block | :interleave + +Note +- For :interleave, assumes all groups have equal length. +""" +function join_bits(parts::Vector{<:AbstractVector{<:Integer}}; layout::Symbol=:block) + G = length(parts) + if layout === :block + return vcat(parts...) + elseif layout === :interleave + Lset = length.(parts) + @assert length(unique(Lset)) == 1 ":interleave requires equal group lengths" + L = Lset[1] + out = Vector{Int}(undef, L*G) + # out = [ g1_1, g2_1, …, gG_1, g1_2, g2_2, …, gG_2, … ] + @inbounds for j in 1:L + base = (j-1)*G + for g in 1:G + out[base + g] = parts[g][j] + end + end + return out + else + error("layout must be :block or :interleave") + end +end + + +function tobins(i, nbit) + @assert 1 ≤ i ≤ 2^nbit + mask = 1 << (nbit-1) + bin = ones(Int, nbit) + for n in 1:nbit + bin[n] = (mask & (i-1)) >> (nbit-n) + 1 + mask = mask >> 1 + end + return bin +end + +function frombins(bin) + @assert all(1 .≤ bin .≤ 2) + nbit = length(bin) + i = 1 + tmp = 2^(nbit-1) + for n in eachindex(bin) + i += tmp * (bin[n] -1) + tmp = tmp >> 1 + end + return i +end diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index 759d43c..590aa40 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -1,26 +1,14 @@ module TreeTCI -import Graphs -import NamedGraphs: - NamedGraph, - NamedEdge, - is_directed, - outneighbors, - has_edge, - edges, - vertices, - src, - dst, - namedgraph_dijkstra_shortest_paths -import TensorCrossInterpolation as TCI -import SimpleTensorNetworks: - TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract -import Random: shuffle +include("imports.jl") include("treegraph_utils.jl") include("simpletci.jl") +include("newstructureproposer.jl") include("pivotcandidateproposer.jl") include("sweep2sitepathproposer.jl") include("simpletci_optimize.jl") include("simpletci_tensors.jl") include("treetensornetwork.jl") +include("structuralsearch.jl") +include("ttnopt.jl") end diff --git a/src/abstracttreetensornetwork.jl b/src/abstracttreetensornetwork.jl new file mode 100644 index 0000000..acfe542 --- /dev/null +++ b/src/abstracttreetensornetwork.jl @@ -0,0 +1,39 @@ +abstract type AbstractTreeTensorNetwork{V} <: Function end + +""" + function evaluate( + ttn::TreeTensorNetwork{V}, + indexset::Union{AbstractVector{Int}, NTuple{N, Int}} + )::V where {V} + +Evaluates the tensor train `tt` at indices given by `indexset`. +""" +function evaluate( + ttn::AbstractTreeTensorNetwork{V}, + indexset::Union{AbstractVector{Int},NTuple{N,Int}}, +)::V where {N,V} + if length(indexset) != length(ttn.sitetensors) + throw( + ArgumentError( + "To evaluate a tt of length $(length(ttn)), you have to provide $(length(ttn)) indices, but there were $(length(indexset)).", + ), + ) + end + sitetensors = IndexedArray[] + for (Tinfo, i) in zip(ttn.sitetensors, indexset) + T, edges = Tinfo + inds = (i, ntuple(_ -> :, ndims(T) - 1)...) + T = T[inds...] + indexs = [ + Index(size(T)[j], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges) + ] + t = IndexedArray(T, indexs) + push!(sitetensors, t) + end + tn = TensorNetwork(sitetensors) + return only(complete_contraction(tn)) +end + +function (ttn::AbstractTreeTensorNetwork{V})(indexset) where {V} + return evaluate(ttn, indexset) +end diff --git a/src/imports.jl b/src/imports.jl new file mode 100644 index 0000000..9709932 --- /dev/null +++ b/src/imports.jl @@ -0,0 +1,25 @@ +using Random +using Combinatorics +using Graphs: simplecycles_limited_length, has_edge, SimpleGraph, center, steiner_tree +using NamedGraphs: + NamedGraph, + NamedEdge, + is_cyclic, + is_directed, + neighbors, + outneighbors, + has_edge, + edges, + vertices, + namedgraph_dijkstra_shortest_paths +using NamedGraphs.GraphsExtensions: + src, + dst, + is_connected, + degree, + add_vertices!, add_vertex!, rem_vertices!, rem_vertex!, + rem_edge!, add_edge! +import TensorCrossInterpolation as TCI +import SimpleTensorNetworks: TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract +import DataGraphs: underlying_graph + diff --git a/src/newstructureproposer.jl b/src/newstructureproposer.jl new file mode 100644 index 0000000..6c092c2 --- /dev/null +++ b/src/newstructureproposer.jl @@ -0,0 +1,28 @@ +""" +Abstract type for structure proposal methods +""" +abstract type AbstractNewStructureProposer end + +struct NewStructureLocalSwap <: AbstractNewStructureProposer end + +struct NewStructureGlobalSwap <: AbstractNewStructureProposer end + +function generate_new_structure( + ::NewStructureLocalSwap, + tci::SimpleTCI{ValueType}, + edge::NamedEdge, +) where {ValueType} + edge in edges(tci.g) || error("Edge $edge not in graph") + vs = src(edge) => dst(edge) + return swap_2site(tci.g, vs) +end + +function generate_new_structure( + ::NewStructureGlobalSwap, + tci::SimpleTCI{ValueType}, + vs::Pair{Int, Int}, +) where {ValueType} + first(vs) in vertices(tci.g) || error("Vertex $first(vs) not in graph") + last(vs) in vertices(tci.g) || error("Vertex $last(vs) not in graph") + return swap_2site(tci.g, vs) +end diff --git a/src/pivotcandidateproposer.jl b/src/pivotcandidateproposer.jl index 5a65ca2..3659493 100644 --- a/src/pivotcandidateproposer.jl +++ b/src/pivotcandidateproposer.jl @@ -44,10 +44,9 @@ function generate_pivot_candidates( Iset = kronecker(Ipivots, Isite_index, tci.localdims[vp]) Jset = kronecker(Jpivots, Jsite_index, tci.localdims[vq]) - extraIJset = if length(tci.IJset_history) > 0 - extraIJset = tci.IJset_history[end] - else - Dict(key => MultiIndex[] for key in keys(tci.IJset)) + extraIJset = tci.IJset + for (key, pivots) in tci.converged_IJset + extraIJset[key] = pivots end Icombined = union(Iset, extraIJset[Ikey]) diff --git a/src/simpletci.jl b/src/simpletci.jl index 963d1dc..4d49f29 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -34,12 +34,12 @@ addglobalpivots!(tci, [[1,1,1], [2,1,1]]) """ mutable struct SimpleTCI{ValueType} IJset::Dict{SubTreeVertex,Vector{MultiIndex}} + converged_IJset::Dict{SubTreeVertex,Vector{MultiIndex}} localdims::Vector{Int} g::NamedGraph bonderrors::Dict{NamedEdge,Float64} pivoterrors::Vector{Float64} maxsamplevalue::Float64 - IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}} function SimpleTCI{ValueType}(localdims::Vector{Int}, g::NamedGraph) where {ValueType} n = length(localdims) @@ -47,20 +47,23 @@ mutable struct SimpleTCI{ValueType} n == length(vertices(g)) || error( "The number of vertices in the graph must be equal to the length of localdims.", ) - !Graphs.is_cyclic(g) || + !is_cyclic(g) || error("SimpleTCI is not supported for loopy tensor network.") # assign the key for each bond - bonderrors = Dict(e => 0.0 for e in edges(g)) + bonderrors = Dict(e => typemax(Float64) for e in edges(g)) + + !is_cyclic(g) || + error("TreeTensorNetwork is not supported for loopy tensor network.") new{ValueType}( Dict{SubTreeVertex,Vector{MultiIndex}}(), # IJset + Dict{SubTreeVertex,Vector{MultiIndex}}(), # converged_IJset localdims, g, bonderrors, Float64[], 0.0, # maxsamplevalue - Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}(), # IJset_history ) end end @@ -115,6 +118,7 @@ function addglobalpivots!( nothing end + function pushunique!(collection, item) if !(item in collection) push!(collection, item) diff --git a/src/simpletci_optimize.jl b/src/simpletci_optimize.jl index 04a8f3f..7417fbc 100644 --- a/src/simpletci_optimize.jl +++ b/src/simpletci_optimize.jl @@ -48,7 +48,9 @@ function optimize!( loginterval::Int = 10, normalizeerror::Bool = true, ncheckhistory::Int = 3, -) where {ValueType} + ) where {ValueType} + + # Histories of properties for checking convergence. errors = Float64[] ranks = Int[] @@ -62,7 +64,6 @@ function optimize!( ) end - globalpivots = MultiIndex[] for iter = 1:maxiter errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 abstol = tolerance * errornormalization @@ -74,34 +75,32 @@ function optimize!( sweep2site!( tci, - f, - 2; + f; abstol = abstol, maxbonddim = maxbonddim, verbosity = verbosity, sweepstrategy = sweepstrategy, pivotstrategy = pivotstrategy, ) - if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0 - abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots] - nrejections = length(abserr .> abstol) - if nrejections > 0 - println( - " Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)", - ) - flush(stdout) - end - end - push!(errors, last(pivoterror(tci))) - if verbosity > 1 - println( - " Walltime $(1e-9*(time_ns() - tstart)) sec: start searching global pivots", - ) - flush(stdout) + push!(ranks, rank(tci)) + push!(errors, pivoterror(tci)) + + if convergencecriterion( + ranks, + errors, + maxbonddim, + tolerance, + ncheckhistory + ) + if verbosity > 1 + println("Converged at $(iter)th-sweep.") + end + break end end - + + tci.converged_IJset = deepcopy(tci.IJset) errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 return ranks, errors ./ errornormalization end @@ -111,8 +110,7 @@ end """ function sweep2site!( tci::SimpleTCI{ValueType}, - f, - niter::Int; + f; abstol::Float64 = 1e-8, maxbonddim::Int = typemax(Int), sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), @@ -122,24 +120,18 @@ function sweep2site!( edge_path = generate_sweep2site_path(sweepstrategy, tci) - for _ = 1:niter - extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) - - push!(tci.IJset_history, deepcopy(tci.IJset)) - - flushpivoterror!(tci) + flushpivoterror!(tci) - for edge in edge_path - updatepivots!( - tci, - edge, - f; - abstol = abstol, - maxbonddim = maxbonddim, - pivotstrategy = pivotstrategy, - verbosity = verbosity, - ) - end + for edge in edge_path + updatepivots!( + tci, + edge, + f; + abstol = abstol, + maxbonddim = maxbonddim, + pivotstrategy = pivotstrategy, + verbosity = verbosity, + ) end nothing @@ -224,6 +216,10 @@ function updatepivoterror!(tci::SimpleTCI{T}, errors::AbstractVector{Float64}) w nothing end +function rank(tci::SimpleTCI{ValueType}) where {ValueType} + return maximum(length(IJset) for IJset in values(tci.IJset)) +end + function pivoterror(tci::SimpleTCI{T}) where {T} return maxbonderror(tci) end @@ -231,3 +227,20 @@ end function maxbonderror(tci::SimpleTCI{T}) where {T} return maximum(values(tci.bonderrors)) end + +function convergencecriterion( + ranks::AbstractVector{Int}, + errors::AbstractVector{Float64}, + maxbonddim::Int, + tolerance::Float64, + ncheckhistory::Int, +)::Bool + if length(errors) < ncheckhistory + return false + end + lastranks = last(ranks, ncheckhistory) + return ( + all(last(errors, ncheckhistory) .< tolerance) && + minimum(lastranks) == lastranks[end] + ) || all(lastranks .>= maxbonddim) +end \ No newline at end of file diff --git a/src/simpletci_tensors.jl b/src/simpletci_tensors.jl index a5d9ea7..cf9df99 100644 --- a/src/simpletci_tensors.jl +++ b/src/simpletci_tensors.jl @@ -18,7 +18,7 @@ function fillsitetensors( tci::SimpleTCI{ValueType}, f; - center_vertex::Int = 0, + center_vertex::Int = 1, ) where {ValueType} sitetensors = @@ -70,8 +70,8 @@ function sitetensor( return reshape( T, tci.localdims[site], - [length(tci.IJset[key]) for key in Inkeys]..., - [length(tci.IJset[key]) for key in Outkeys]..., + [length(tci.converged_IJset[key]) for key in Inkeys]..., + [length(tci.converged_IJset[key]) for key in Outkeys]..., ) end @@ -85,11 +85,11 @@ function sitetensor( ) where {ValueType} Inkeys, Outkeys = InOutkeys L = length(tci.localdims) - Pi1 = filltensor(ValueType, f, tci.localdims, tci.IJset, Inkeys, Outkeys, Val(1)) + Pi1 = filltensor(ValueType, f, tci.localdims, tci.converged_IJset, Inkeys, Outkeys, Val(1)) Pi1 = reshape( Pi1, - prod(vcat([tci.localdims[site]], [length(tci.IJset[key]) for key in Inkeys])), - prod([length(tci.IJset[key]) for key in Outkeys]), + prod(vcat([tci.localdims[site]], [length(tci.converged_IJset[key]) for key in Inkeys])), + prod([length(tci.converged_IJset[key]) for key in Outkeys]), ) updatemaxsample!(tci, Pi1) @@ -106,17 +106,17 @@ function sitetensor( end P = reshape( - filltensor(ValueType, f, tci.localdims, tci.IJset, [I1key], Outkeys, Val(0)), - length(tci.IJset[I1key]), - prod([length(tci.IJset[key]) for key in Outkeys]), + filltensor(ValueType, f, tci.localdims, tci.converged_IJset, [I1key], Outkeys, Val(0)), + length(tci.converged_IJset[I1key]), + prod([length(tci.converged_IJset[key]) for key in Outkeys]), ) - length(tci.IJset[I1key]) == sum([length(tci.IJset[key]) for key in Outkeys]) || error("Pivot matrix at bond $(site) is not square!") + length(tci.converged_IJset[I1key]) == sum([length(tci.converged_IJset[key]) for key in Outkeys]) || error("Pivot matrix at bond $(site) is not square!") Tmat = transpose(transpose(P) \ transpose(Pi1)) T = reshape( Tmat, tci.localdims[site], - [length(tci.IJset[key]) for key in Inkeys]..., - [length(tci.IJset[key]) for key in Outkeys]..., + [length(tci.converged_IJset[key]) for key in Inkeys]..., + [length(tci.converged_IJset[key]) for key in Outkeys]..., ) return T end diff --git a/src/structuralsearch.jl b/src/structuralsearch.jl new file mode 100644 index 0000000..311f0b1 --- /dev/null +++ b/src/structuralsearch.jl @@ -0,0 +1,358 @@ +function crossinterpolate_with_3site_swapping( + ::Type{ValueType}, + f, + localdims::Union{Vector{Int},NTuple{N,Int}}, + g::NamedGraph, + nsweeps:: Int = 100, + origin_edge = nothing, + initialpivots::Vector{MultiIndex} = [ones(Int, length(localdims))]; + kwargs..., +) where {ValueType,N} + + tci = SimpleTCI{ValueType}(f, localdims, g, initialpivots) + ranks, errors = optimize!(tci, f; kwargs...) + n = length(vertices(tci.g)) + + + g_tmp = deepcopy(tci.g) + + if origin_edge == nothing + d = n + for e in edges(tci.g) + p, q = separatevertices(tci.g, e) + Iset = length(subtreevertices(tci.g, p => q)) + Jset = length(subtreevertices(tci.g, q => p)) + d_tmp = abs(Iset - Jset) + if d_tmp < d + d = d_tmp + origin_edge = e + end + end + end + + id2edge = collect(edges(tci.g)) + edge2id = Dict{NamedEdge,Int}(e => i for (i,e) in enumerate(id2edge)) + center_edge = origin_edge + center_edge_id = edge2id[center_edge] + origin_edge_id = edge2id[origin_edge] + previous_center_edge_id = center_edge_id + + for i = 1:nsweeps + # Init flags + flags = Dict(k => 0 for k in 1:length(id2edge)) + while true + + tci, id2edge = optimize_with_3site_swapping(tci, f, initialpivots, id2edge[center_edge_id], id2edge[previous_center_edge_id], id2edge, edge2id; kwargs...) + edge2id = Dict(e => i for (i, e) in enumerate(id2edge)) + + # update next center edge + candidates = candidateedges(tci.g, id2edge[previous_center_edge_id]) + candidates = [e for e in candidates if flags[edge2id[e]] == 0] + + # If candidates is empty, exit while loop + if isempty(candidates) + break + end + + distances = distanceedges(tci.g, id2edge[origin_edge_id]) + max_distance = maximum(distances[e] for e in candidates) + candidates = filter(e -> distances[e] == max_distance, candidates) + + center_edge_ = first(candidates) + + p, q = separatevertices(tci.g, id2edge[previous_center_edge_id]) + v = center_edge_ in adjacentedges(tci.g, p) ? q : p # + incomings = [edge for edge in adjacentedges(tci.g, v) if edge != id2edge[previous_center_edge_id]] + + # Update flags - ID management + center_edge_id = edge2id[center_edge] + if all(flags[edge2id[e]] == 1 for e in incomings) && center_edge_id != origin_edge_id + flags[center_edge_id] = 1 + end + + # update center edge + center_edge = center_edge_ + + # Structural search + tci, id2edge = optimize_with_3site_swapping(tci, f, initialpivots, id2edge[center_edge_id], id2edge[previous_center_edge_id], id2edge, edge2id; kwargs...) + edge2id = Dict(e => i for (i, e) in enumerate(id2edge)) + + previous_center_edge_id = center_edge_id + end + if tci.g == g_tmp + @show "converged" + break + end + g_tmp = deepcopy(tci.g) # update g_tmp + end + sitetensors, center_vertex = fillsitetensors(tci, f) + + ranks, errors = optimize!(tci, f; kwargs...) + return TreeTensorNetwork(tci.g, sitetensors), ranks, errors +end + +function crossinterpolate_with_structuralsearch( + ::Type{ValueType}, + f, + localdims::Union{Vector{Int},NTuple{N,Int}}, + g::NamedGraph, + max_degree::Int = 1, + nsweeps:: Int = 100, + origin_edge = nothing, + initialpivots::Vector{MultiIndex} = [ones(Int, length(localdims))]; + kwargs..., +) where {ValueType,N} + + tci = SimpleTCI{ValueType}(f, localdims, g, initialpivots) + ranks, errors = optimize!(tci, f; kwargs...) + n = length(vertices(tci.g)) + + g_tmp = deepcopy(tci.g) + + if origin_edge == nothing + d = n + for e in edges(tci.g) + p, q = separatevertices(tci.g, e) + Iset = length(subtreevertices(tci.g, p => q)) + Jset = length(subtreevertices(tci.g, q => p)) + d_tmp = abs(Iset - Jset) + if d_tmp < d + d = d_tmp + origin_edge = e + end + end + end + + id2edge = collect(edges(tci.g)) + edge2id = Dict{NamedEdge,Int}(e => i for (i,e) in enumerate(id2edge)) + center_edge = origin_edge + origin_edge_id = edge2id[origin_edge] + + for i = 1:nsweeps + # Init flags + flags = Dict(k => 0 for k in 1:length(id2edge)) + while true + + # Structural search + tci, id2edge = optimize_with_localmanipulation(tci, f, initialpivots, center_edge, max_degree, id2edge, edge2id; kwargs...) + edge2id = Dict(e => i for (i, e) in enumerate(id2edge)) + + # update next center edge + candidates = candidateedges(tci.g, center_edge) + candidates = [e for e in candidates if flags[edge2id[e]] == 0] + + # If candidates is empty, exit while loop + if isempty(candidates) + break + end + + distances = distanceedges(tci.g, id2edge[origin_edge_id]) + max_distance = maximum(distances[e] for e in candidates) + candidates = filter(e -> distances[e] == max_distance, candidates) + + center_edge_ = first(candidates) + + p, q = separatevertices(tci.g, center_edge) + v = center_edge_ in adjacentedges(tci.g, p) ? q : p # + incomings = [edge for edge in adjacentedges(tci.g, v) if edge != center_edge] + + # Update flags - ID management + center_edge_id = edge2id[center_edge] + if all(flags[edge2id[e]] == 1 for e in incomings) && center_edge_id != origin_edge_id + flags[center_edge_id] = 1 + end + + # update center edge + center_edge = center_edge_ + end + if tci.g == g_tmp + @show "converged" + break + end + g_tmp = deepcopy(tci.g) # update g_tmp + end + sitetensors = fillsitetensors(tci, f) + + ranks, errors = optimize!(tci, f; kwargs...) + return TreeTensorNetwork(tci.g, sitetensors), ranks, errors +end + +# Optimization with local 3-site tensor swapping +function optimize_with_3site_swapping( + tci::SimpleTCI{ValueType}, + f, + initialpivots::Vector{MultiIndex}, + center_edge::NamedEdge, + adjacent_edge::NamedEdge, + id2edge, + edge2id; + kwargs... +) where {ValueType} + best_err = maximum(values(tci.bonderrors)) + + tci_best = tci + id2edge_best = id2edge + + + # center_edgeとadjacent_edgeからp,q,rを抽出(qが両方に接続) + tmp1, tmp2 = src(center_edge), dst(center_edge) + tmp3, tmp4 = src(adjacent_edge), dst(adjacent_edge) + + # qが両方のエッジに接続するようにp, q, rを決定 + if tmp1 in (tmp3, tmp4) + q = tmp1 + p = tmp2 + r = tmp1 == tmp3 ? tmp4 : tmp3 + elseif tmp2 in (tmp3, tmp4) + q = tmp2 + p = tmp1 + r = tmp2 == tmp3 ? tmp4 : tmp3 + else + @warn "No common node between center_edge and adjacent_edge" + return tci, id2edge + end + + @assert r != p && r != q + + # get the subtree sets that are connected to each node + p_sub = filter(x -> x ∉ (q, r), neighbors(tci.g, p)) + q_sub = filter(x -> x ∉ (p, r), neighbors(tci.g, q)) + r_sub = filter(x -> x ∉ (p, q), neighbors(tci.g, r)) + + for (a, b, c) in permutations((p, q, r)) + g_new = deepcopy(tci.g) + id2edge_new = deepcopy(id2edge) + + for (v1, v2) in ((p, q), (q, r)) + e_old = NamedEdge(min(v1, v2) => max(v1, v2)) + if has_edge(g_new, e_old) + e_old_id = edge2id[e_old] + rem_edge!(g_new, e_old) + end + end + + e_ab = NamedEdge(min(a, b) => max(a, b)) + e_bc = NamedEdge(min(b, c) => max(b, c)) + add_edge!(g_new, e_ab) + add_edge!(g_new, e_bc) + + # 3ノード間エッジのID更新(optimize_with_localmanipulationと同じパターン) + if haskey(edge2id, NamedEdge(min(p, q) => max(p, q))) + e_pq_id = edge2id[NamedEdge(min(p, q) => max(p, q))] + id2edge_new[e_pq_id] = e_ab + end + if haskey(edge2id, NamedEdge(min(q, r) => max(q, r))) + e_qr_id = edge2id[NamedEdge(min(q, r) => max(q, r))] + id2edge_new[e_qr_id] = e_bc + end + + # サブグラフの再接続(重複エッジを避ける) + # 各ノードのサブグラフを新しい親に再接続 + for (old_node, new_node) in [(p, a), (q, b), (r, c)] + if old_node != new_node # ノードが変わった場合のみ再接続 + # 元のノードからサブグラフを取得 + old_sub = filter(x -> x ∉ (p, q, r), neighbors(tci.g, old_node)) + + # 新しいノードにサブグラフを再接続 + for sub_node in old_sub + e_old = NamedEdge(min(old_node, sub_node) => max(old_node, sub_node)) + if has_edge(g_new, e_old) && haskey(edge2id, e_old) + e_old_id = edge2id[e_old] + rem_edge!(g_new, e_old) + + # 新しいエッジを追加 + e_new = NamedEdge(min(new_node, sub_node) => max(new_node, sub_node)) + id2edge_new[e_old_id] = e_new + add_edge!(g_new, e_new) + end + end + end + end + + tci_tmp = SimpleTCI{ValueType}(f, tci.localdims, g_new, initialpivots) + # tci_tmp.converged_IJset = tci.converged_IJset + _, _ = optimize!(tci_tmp, f; kwargs...) + err = maximum(values(tci_tmp.bonderrors)) + + # 7. update best + if err < best_err + best_err = err + tci_best = deepcopy(tci_tmp) + id2edge_best = deepcopy(id2edge_new) + end + end + + return tci_best, id2edge_best +end + +# Optimization with local 2-site tensor manipulation +function optimize_with_localmanipulation( + tci::SimpleTCI{ValueType}, + f, + initialpivots::Vector{MultiIndex}, + center_edge::NamedEdge, + max_degree::Int, + id2edge, + edge2id; + kwargs... +) where {ValueType} + + best_err = maximum(values(tci.bonderrors)) + tci_best = tci + id2edge_best = id2edge + + # calculate the two ends of the split object and its child node list + p, q = src(center_edge), dst(center_edge) + subI = filter(x->x!=q, neighbors(tci.g,p)) + subJ = filter(x->x!=p, neighbors(tci.g,q)) + children = vcat(subI, subJ) + n = length(children) + limit = max_degree + + # k nodes are left (p side), the rest are right (q side) + for k in 0:n + if k ≤ limit && (n-k) ≤ limit + for left in combinations(children, k) + leftset = Set(left) + + # copy the graph and replace one by one + g_new = deepcopy(tci.g) + id2edge_new = deepcopy(id2edge) + rem_edge!(g_new, center_edge) + + for v in children + # remove the old parent + old_parent = (v in subI ? p : q) + + e_old = old_parent < v ? NamedEdge(old_parent=>v) : NamedEdge(v=>old_parent) + e_old_id = edge2id[e_old] + + rem_edge!(g_new, e_old) + # connect the new parent + new_parent = (v in leftset ? p : q) + e_new = new_parent < v ? NamedEdge(new_parent=>v) : NamedEdge(v=>new_parent) + + id2edge_new[e_old_id] = e_new + add_edge!(g_new, e_new) + end + add_edge!(g_new, center_edge) + + # optimize + tci_tmp = SimpleTCI{ValueType}(f, tci.localdims, g_new, initialpivots) + # tci_tmp.converged_IJset = tci.converged_IJset + _, _ = optimize!(tci_tmp, f; kwargs...) + err = maximum(values(tci_tmp.bonderrors)) + + # update best + if err < best_err && !(err ≈ best_err) + best_err = err + tci_best = deepcopy(tci_tmp) + id2edge_best = deepcopy(id2edge_new) + end + + end + end + end + + return tci_best, id2edge_best +end \ No newline at end of file diff --git a/src/sweep2sitepathproposer.jl b/src/sweep2sitepathproposer.jl index ae1139b..5dce04b 100644 --- a/src/sweep2sitepathproposer.jl +++ b/src/sweep2sitepathproposer.jl @@ -44,14 +44,14 @@ LocalAdjacent strategy that runs through within all indices of site tensor accor function generate_sweep2site_path( ::LocalAdjacentSweep2sitePathProposer, tci::SimpleTCI{ValueType}; - origin_edge = undef, + origin_edge = nothing, ) where {ValueType} edge_path = Vector{NamedEdge}() n = length(vertices(tci.g)) # choose the center bond id. - if origin_edge == undef + if origin_edge == nothing d = n for e in edges(tci.g) p, q = separatevertices(tci.g, e) diff --git a/src/treegraph_utils.jl b/src/treegraph_utils.jl index e0c4a86..9f37fee 100644 --- a/src/treegraph_utils.jl +++ b/src/treegraph_utils.jl @@ -82,3 +82,45 @@ function distanceedges(g::NamedGraph, edge::NamedEdge)::Dict{NamedEdge,Int} distances[edge] = 0 return distances end + +function swap_2site(g_old::NamedGraph, vs::Pair{Int, Int}) + g = deepcopy(g_old) + p, q = vs + p_neighbors = neighbors(g, p) + q_neighbors = neighbors(g, q) + + rem_vertices!(g, [p, q]) + add_vertex!(g, p) + add_vertex!(g, q) + + for p_i in p_neighbors + p_i == q && continue # avoid self-loop + add_edge!(g, NamedEdge(p_i => q)) + end + + for q_i in q_neighbors + q_i == p && continue # avoid self-loop + add_edge!(g, NamedEdge(q_i => p)) + end + + if has_edge(g_old, NamedEdge(p => q)) + add_edge!(g, NamedEdge(p => q)) + end + + return g +end + +function add_subtree( + g_old::NamedGraph, + v::Int, + edge::NamedEdge, +) + g = deepcopy(g_old) + p, q = separatevertices(g, edge) + p_regions = subtreevertices(g, q => p) + q_regions = subtreevertices(g, p => q) + parent = v in p_regions ? q : p + rem_edge!(g, edge) + add_edge!(g, NamedEdge(parent => v)) + return g +end diff --git a/src/treetensornetwork.jl b/src/treetensornetwork.jl index 1499d90..43836f9 100644 --- a/src/treetensornetwork.jl +++ b/src/treetensornetwork.jl @@ -5,7 +5,7 @@ mutable struct TreeTensorNetwork{ValueType} g::NamedGraph, sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}, ) where {ValueType} - !Graphs.is_cyclic(g) || + !is_cyclic(g) || error("TreeTensorNetwork is not supported for loopy tensor network.") ttntensors = Vector{IndexedArray}() for (i, (T, edges)) in enumerate(sitetensors) @@ -82,11 +82,12 @@ function crossinterpolate( localdims::Union{Vector{Int},NTuple{N,Int}}, g::NamedGraph, initialpivots::Vector{MultiIndex} = [ones(Int, length(localdims))]; + center_vertex::Int = 1, kwargs..., ) where {ValueType,N} tci = SimpleTCI{ValueType}(f, localdims, g, initialpivots) ranks, errors = optimize!(tci, f; kwargs...) - sitetensors = fillsitetensors(tci, f) + sitetensors = fillsitetensors(tci, f; center_vertex = center_vertex) return TreeTensorNetwork(tci.g, sitetensors), ranks, errors end @@ -121,3 +122,18 @@ end function Base.length(ttn::TreeTensorNetwork) return length(vertices(ttn.tensornetwork.data_graph)) end + +# Add method to get graph structure from TreeTensorNetwork +function get_graph(ttn::TreeTensorNetwork) + return ttn.tensornetwork.data_graph +end + +# Add method to get vertices from TreeTensorNetwork +function get_vertices(ttn::TreeTensorNetwork) + return vertices(ttn.tensornetwork.data_graph) +end + +# Add method to get edges from TreeTensorNetwork +function get_edges(ttn::TreeTensorNetwork) + return edges(ttn.tensornetwork.data_graph) +end diff --git a/src/ttnopt.jl b/src/ttnopt.jl new file mode 100644 index 0000000..4c765c6 --- /dev/null +++ b/src/ttnopt.jl @@ -0,0 +1,191 @@ +using ITensors +using ITensorNetworks +const ITN = ITensorNetworks + +function entanglements_entropy(s) + s = diag(s) + s2 = s.^2 + s2 = s2 / sum(s2) + s2 = s2[s2 .> 0.0] + return -sum(s2 .* log.(s2)) +end + +function ttnopt( + ttn::TreeTensorNetwork, + nsweeps::Int=50; + ortho_vertex::Int=1, + max_degree::Int = 1, + T0::Float64 = 0.0, +) + ttn = convert_ITensorNetwork(ttn, ortho_vertex) + normalize!(ttn) + neighbor_vertices = neighbors(ttn, ortho_vertex) + next_vertex = first(neighbor_vertices) + origin_edge = NamedEdge(min(ortho_vertex, next_vertex) => max(ortho_vertex, next_vertex)) + center_edge = origin_edge + + flag_indices = [tags(first(ITN.linkinds(ttn, e))) for e in edges(ttn)] + origin_flag_index = tags(first(ITN.linkinds(ttn, origin_edge))) + + edge_list = [[src(e), dst(e)] for e in edges(ttn)] + + original_entanglements = Dict() + final_entanglements = Dict() + + for sweep = 0:nsweeps + flags = Dict(flag_indices[i] => 0 for i in 1:length(flag_indices)) + final_entanglements = Dict() + while true + g = ttn.tensornetwork.data_graph.underlying_graph + p, q = separatevertices(g, center_edge) + linkind = first(ITN.linkinds(ttn, center_edge)) + maxbonddim = dim(linkind) + tag = tags(linkind) + siteindices = [ITN.siteinds(ttn, p); ITN.siteinds(ttn, q)] + + ψ = ITensors.contract(ttn[p], ttn[q]) + if sweep == 0 + leftinds = filter(ind -> ind != ITN.siteinds(ttn, p), inds(ttn[p])) + _, s, _ = ITensors.svd(ψ, leftinds; cutoff = 0.0) + ee = entanglements_entropy(s) + original_entanglements[src(center_edge), dst(center_edge)] = ee + else + entanglements, leftinds_list = propose_structure(ψ, siteindices, max_degree) + index = decide_structure(entanglements, sweep, nsweeps; T0 = T0) + leftinds = leftinds_list[index] + ee = entanglements[index] + final_entanglements[src(center_edge), dst(center_edge)] = ee + end + + u, s, v = ITensors.svd(ψ, leftinds; maxdim = maxbonddim, lefttags = tag, righttags = tag) + ttn[p] = u + ttn[q] = v * s + + # update next center edge + g = ttn.tensornetwork.data_graph.underlying_graph + + candidates = candidateedges(g, center_edge) + candidates = [e for e in candidates if flags[tags(first(ITN.linkinds(ttn, e)))] == 0] + + # If candidates is empty, exit while loop + if isempty(candidates) + break + end + + same_index_edge = first([e for e in edges(ttn) if tags(first(ITN.linkinds(ttn, e))) == origin_flag_index]) + distances = distanceedges(g, same_index_edge) + max_distance = maximum(distances[e] for e in candidates) + candidates = filter(e -> distances[e] == max_distance, candidates) + center_edge_ = first(candidates) + prev_vertex, next_vertex = center_edge_ in adjacentedges(g, p) ? (q, p) : (p, q) # + incomings = [first(ITN.linkinds(ttn, edge)) for edge in adjacentedges(g, prev_vertex) if edge != center_edge] + center_flag_index = tags(first(ITN.linkinds(ttn, center_edge))) + + if all(flags[tags(e)] == 1 for e in incomings) && center_flag_index != origin_flag_index + flags[center_flag_index] = 1 + end + + if next_vertex == p + ttn[next_vertex] = u * s + ttn[prev_vertex] = v + end + center_edge = center_edge_ + + end + + new_edge_list = [[src(e), dst(e)] for e in edges(ttn)] + if sweep > 1 && Set(edge_list) == Set(new_edge_list) + break + end + edge_list = new_edge_list + end + return ttn.tensornetwork.data_graph.underlying_graph, original_entanglements, final_entanglements +end + +function propose_structure(ψ, siteindices, max_degree::Int) + s1_ind = siteindices[1] + s2_ind = siteindices[2] + + remain_inds = filter(ind -> ind != s1_ind && ind != s2_ind, inds(ψ)) + n = length(remain_inds) + + entanglements = Float64[] + leftinds_list = [] + for k = 0:n + if k <= max_degree && (n-k) <= max_degree + for left in combinations(remain_inds, k) + leftinds = [s1_ind; left] + _, s, _ = svd(ψ, leftinds, cutoff = 0.0) + ee = entanglements_entropy(s) + push!(entanglements, ee) + push!(leftinds_list, leftinds) + end + end + end + return entanglements, leftinds_list +end + +function decide_structure(entanglements, nowsweep, nsweeps; T0::Float64 = 0.0) + if T0 > 0.0 + T = T0 * nowsweep / nsweeps + p = exp.(-entanglements / T) + p = p / sum(p) + index = sample(1:length(entanglements), Weights(p)) + else + index = argmin(entanglements) + end + return index +end + +function convert_ITensorNetwork(ttn::TreeTensorNetwork, ortho_vertex::Int=1) + g = ttn.tensornetwork.data_graph.underlying_graph + siteinds = [] + edge_inds = Dict{NamedEdge, ITensors.Index}() + itensors = ITensors.ITensor[] + for v in vertices(g) + tns = ttn.tensornetwork.data_graph[v] + tns_data = tns.data + tns_inds = tns.indices + # Add site index + site = ITensors.Index(tns_inds[1].dim; tags="s$v") + push!(siteinds, site) + # Add edge indices + inds = [] + for ind in tns_inds[2:end] + tag = ind.name + parts = split(tag, "=>") + src, dst = parse(Int, parts[1]), parse(Int, parts[2]) + e = NamedEdge(src, dst) + # Get edge ID from mapping + edge_id = findfirst(edge -> edge == e, collect(edges(g))) + ind = get!(edge_inds, NamedEdge(src, dst), ITensors.Index(ind.dim; tags="e$edge_id")) + push!(inds, ind) + end + # Add itensor + itensor = ITensors.ITensor(tns_data, [site; inds]) + push!(itensors, itensor) + end + + # ortho normalize + state = namedgraph_dijkstra_shortest_paths(g, ortho_vertex) + distances = state.dists + max_distance = maximum(distances[v] for v in vertices(g)) + for d = max_distance:-1:1 + children = filter(v -> distances[v] == d, vertices(g)) + for child in children + parent = state.parents[child] + ψ = ITensors.contract(itensors[parent], itensors[child]) + virtualindex = commonind(itensors[parent], itensors[child]) + child_inds = inds(itensors[child]) + left_inds = filter(i -> i != virtualindex, child_inds) + u, s, v = svd(ψ, left_inds, maxdim = dim(virtualindex); lefttags=tags(virtualindex), righttags=tags(virtualindex)) + v = v * s + itensors[child] = u + itensors[parent] = v + end + end + + ttn = ITN.ITensorNetwork(itensors) + ttn = ITN.TreeTensorNetwork(ttn, ortho_region=vertices(ttn)[ortho_vertex]) + return ttn +end \ No newline at end of file