diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2251642
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+Manifest.toml
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index 716a020..bccfef1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,9 +6,8 @@ version = "0.1.0"
 [deps]
 DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
-JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
 NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
-Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SimpleTensorNetworks = "3075f829-f72e-4896-a859-7fe0a9cabb9b"
 TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604"
 
@@ -16,6 +15,7 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604"
 DataGraphs = "0.2.5"
 Graphs = "1.12.0"
 NamedGraphs = "0.6.4"
+Random = "1.10"
 SimpleTensorNetworks = "0.1.0"
 TensorCrossInterpolation = "0.9.13"
 
diff --git a/docs/make.jl b/docs/make.jl
index 6f10cfa..d1a07aa 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -12,10 +12,7 @@ makedocs(;
         edit_link = "main",
         assets = String[],
     ),
-    pages = [
-        "Home" => "index.md",
-        "API Reference" => "api.md"
-    ],
+    pages = ["Home" => "index.md", "API Reference" => "api.md"],
 )
 
 deploydocs(; repo = "github.com/tensor4all/TreeTCI.jl.git", devbranch = "main")
diff --git a/samples/devtci.jl b/samples/devtci.jl
deleted file mode 100644
index cfb77d6..0000000
--- a/samples/devtci.jl
+++ /dev/null
@@ -1,26 +0,0 @@
-using Revise
-using TreeTCI
-using NamedGraphs: NamedGraph, add_edge!, edges
-
-function main()
-    localdims = fill(2, 7)
-    g = NamedGraph(7)
-    add_edge!(g, 1, 2)
-    add_edge!(g, 2, 3)
-    add_edge!(g, 2, 4)
-    add_edge!(g, 4, 5)
-    add_edge!(g, 5, 6)
-    add_edge!(g, 5, 7)
-
-    f(v) = 1 / (1 + v' * v)
-    tolerance = 1e-8
-
-    mpn, ranks, errors = TreeTCI.TCI.crossinterpolate2(Float64, f, localdims; tolerance = tolerance)
-    ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g)
-    @show f([1, 1, 1, 1, 2, 1, 1]), f([1, 2, 1, 2, 2, 1, 1]), f([2, 2, 2, 2, 2, 2, 2])
-    @show mpn([1, 1, 1, 1, 2, 1, 1]), mpn([1, 2, 1, 2, 2, 1, 1]), mpn([2, 2, 2, 2, 2, 2, 2])
-    @show ttn([1, 1, 1, 1, 2, 1, 1]), ttn([1, 2, 1, 2, 2, 1, 1]), ttn([2, 2, 2, 2, 2, 2, 2])
-    nothing
-end
-
-main()
diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl
index aab7804..8fc82a2 100644
--- a/src/TreeTCI.jl
+++ b/src/TreeTCI.jl
@@ -2,13 +2,25 @@ module TreeTCI
 import Graphs
 import NamedGraphs:
-    NamedGraph, NamedEdge, is_directed, outneighbors, has_edge, edges, vertices, src, dst, namedgraph_dijkstra_shortest_paths
+    NamedGraph,
+    NamedEdge,
+    is_directed,
+    outneighbors,
+    has_edge,
+    edges,
+    vertices,
+    src,
+    dst,
+    namedgraph_dijkstra_shortest_paths
 import TensorCrossInterpolation as TCI
-import SimpleTensorNetworks: TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract
-include("tree_utils.jl")
+import SimpleTensorNetworks:
+    TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract
+import Random: shuffle
+include("treegraph_utils.jl")
 include("simpletci.jl")
-include("simpletci_utils.jl")
 include("pivotcandidateproper.jl")
 include("sweep2sitepathproper.jl")
+include("simpletci_optimize.jl")
+include("simpletci_tensors.jl")
 include("treetensornetwork.jl")
 
 end
diff --git a/src/abstracttreetensornetwork.jl b/src/abstracttreetensornetwork.jl
deleted file mode 100644
index 0a83d47..0000000
--- a/src/abstracttreetensornetwork.jl
+++ /dev/null
@@ -1,40 +0,0 @@
-abstract type 
AbstractTreeTensorNetwork{V} <: Function end - -""" - function evaluate( - ttn::TreeTensorNetwork{V}, - indexset::Union{AbstractVector{Int}, NTuple{N, Int}} - )::V where {V} - -Evaluates the tensor train `tt` at indices given by `indexset`. -""" -function evaluate( - ttn::AbstractTreeTensorNetwork{V}, - indexset::Union{AbstractVector{Int},NTuple{N,Int}}, -)::V where {N,V} - if length(indexset) != length(ttn.sitetensors) - throw( - ArgumentError( - "To evaluate a tt of length $(length(ttn)), you have to provide $(length(ttn)) indices, but there were $(length(indexset)).", - ), - ) - end - sitetensors = IndexedArray[] - # TODO: site tensorを作る関数を作成 - for (Tinfo, i) in zip(ttn.sitetensors, indexset) - T, edges = Tinfo - inds = (i, ntuple(_ -> :, ndims(T) - 1)...) - T = T[inds...] - indexs = [ - Index(size(T)[j], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges) - ] - t = IndexedArray(T, indexs) - push!(sitetensors, t) - end - tn = TensorNetwork(sitetensors) - return only(complete_contraction(tn)) -end - -function (ttn::AbstractTreeTensorNetwork{V})(indexset) where {V} - return evaluate(ttn, indexset) -end diff --git a/src/pivotcandidateproper.jl b/src/pivotcandidateproper.jl index 9c4db9c..7ec8659 100644 --- a/src/pivotcandidateproper.jl +++ b/src/pivotcandidateproper.jl @@ -1,12 +1,12 @@ """ Abstract type for pivot candidate generation strategies """ -abstract type PivotCandidateProper end +abstract type AbstractPivotCandidateProper end """ Default strategy that uses kronecker product and union with extra indices """ -struct DefaultPivotCandidateProper <: PivotCandidateProper end +struct DefaultPivotCandidateProper <: AbstractPivotCandidateProper end """ Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors @@ -22,10 +22,10 @@ function generate_pivot_candidates( Ikey, subIkey = subtreevertices(tci.g, vq => vp), vp Jkey, subJkey = subtreevertices(tci.g, vp => vq), vq - adjacent_edges_vp = adjacentedges(tci.g, vp; combinededges=edge) + adjacent_edges_vp = adjacentedges(tci.g, vp; combinededges = edge) InIkeys = edgeInIJkeys(tci.g, vp, adjacent_edges_vp) - adjacent_edges_vq = adjacentedges(tci.g, vq; combinededges=edge) + adjacent_edges_vq = adjacentedges(tci.g, vq; combinededges = edge) InJkeys = edgeInIJkeys(tci.g, vq, adjacent_edges_vq) # Generate base index sets for both sides diff --git a/src/simpletci.jl b/src/simpletci.jl index 0299bad..a4c41d6 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -1,8 +1,6 @@ MultiIndex = Vector{Int} SubTreeVertex = Vector{Int} -using Base: SimpleLogger - mutable struct SimpleTCI{ValueType} IJset::Dict{SubTreeVertex,Vector{MultiIndex}} localdims::Vector{Int} @@ -51,9 +49,9 @@ function SimpleTCI{ValueType}( return tci end -@doc""" -Add global pivots to index sets -""" +@doc """ + Add global pivots to index sets + """ function addglobalpivots!( tci::SimpleTCI{ValueType}, pivots::Vector{MultiIndex}, @@ -83,302 +81,6 @@ function addglobalpivots!( nothing end -@doc""" - optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) - -Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. 
- -# Arguments -- `tci`: The SimpleTCI object to optimize -- `f`: The function to interpolate - -# Keywords -- `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence -- `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead -- `maxbonddim::Int = typemax(Int)`: Maximum bond dimension -- `maxiter::Int = 20`: Maximum number of iterations -- `sweepstrategy::Symbol = :backandforth`: Strategy for sweeping -- `pivotsearch::Symbol = :full`: Strategy for pivot search -- `verbosity::Int = 0`: Verbosity level -- `loginterval::Int = 10`: Interval for logging -- `normalizeerror::Bool = true`: Whether to normalize errors -- `ncheckhistory::Int = 3`: Number of history steps to check -- `maxnglobalpivot::Int = 5`: Maximum number of global pivots -- `nsearchglobalpivot::Int = 5`: Number of global pivots to search -- `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search -- `strictlynested::Bool = false`: Whether to enforce strict nesting -- `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable - -# Returns -- `ranks`: Vector of ranks at each iteration -- `errors`: Vector of normalized errors at each iteration -""" -function optimize!( - tci::SimpleTCI{ValueType}, - f; - tolerance::Union{Float64,Nothing} = nothing, - pivottolerance::Union{Float64,Nothing} = nothing, - maxbonddim::Int = typemax(Int), - maxiter::Int = 20, - sweepstrategy::Symbol = :backandforth, # TODO: Implement for Tree structure - pivotsearch::Symbol = :full, - verbosity::Int = 0, - loginterval::Int = 10, - normalizeerror::Bool = true, - ncheckhistory::Int = 3, - maxnglobalpivot::Int = 5, - nsearchglobalpivot::Int = 5, - tolmarginglobalsearch::Float64 = 10.0, - strictlynested::Bool = false, - checkbatchevaluatable::Bool = false, -) where {ValueType} - errors = Float64[] - ranks = Int[] - nglobalpivots = Int[] - local tol::Float64 - - if checkbatchevaluatable && !(f isa BatchEvaluator) - error("Function `f` is not batch evaluatable") - end - - if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot - error("nsearchglobalpivot < maxnglobalpivot!") - end - - # Deprecate the pivottolerance option - if !isnothing(pivottolerance) - if !isnothing(tolerance) && (tolerance != pivottolerance) - throw( - ArgumentError( - "Got different values for pivottolerance and tolerance in optimize!(TCI2). For TCI2, both of these options have the same meaning. Please assign only `tolerance`.", - ), - ) - else - @warn "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future." - tol = pivottolerance - end - elseif !isnothing(tolerance) - tol = tolerance - else # pivottolerance == tolerance == nothing, therefore set tol to default value - tol = 1e-8 - end - - tstart = time_ns() - - if maxbonddim >= typemax(Int) && tol <= 0 - throw( - ArgumentError( - "Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!", - ), - ) - end - - globalpivots = MultiIndex[] - for iter = 1:maxiter - errornormalization = normalizeerror ? 
tci.maxsamplevalue : 1.0 - abstol = tol * errornormalization - - if verbosity > 1 - println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep") - flush(stdout) - end - - sweep2site!( - tci, - f, - 2; - iter1 = 1, - abstol = abstol, - maxbonddim = maxbonddim, - pivotsearch = pivotsearch, - verbosity = verbosity, - sweepstrategy = sweepstrategy, - ) - if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0 - abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots] - nrejections = length(abserr .> abstol) - if nrejections > 0 - println( - " Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)", - ) - flush(stdout) - end - end - push!(errors, last(pivoterror(tci))) - - if verbosity > 1 - println( - " Walltime $(1e-9*(time_ns() - tstart)) sec: start searching global pivots", - ) - flush(stdout) - end - end - - errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 - return ranks, errors ./ errornormalization -end - -@doc""" -Perform 2site sweeps on a SimpleTCI. -!TODO: Implement for Tree structure - -""" -function sweep2site!( - tci::SimpleTCI{ValueType}, - f, - niter::Int; - iter1::Int = 1, - abstol::Float64 = 1e-8, - maxbonddim::Int = typemax(Int), - sweepstrategy::Symbol = :backandforth, - pivotsearch::Symbol = :full, - verbosity::Int = 0, -) where {ValueType} - - edge_path = generate_sweep2site_path(DefaultSweep2sitePathProper(), tci) - - - for iter = iter1:iter1+niter-1 - extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) - if length(tci.IJset_history) > 0 - extraIJset = tci.IJset_history[end] - end - - push!(tci.IJset_history, deepcopy(tci.IJset)) - - flushpivoterror!(tci) - - for edge in edge_path - updatepivots!( - tci, - edge, - f; - abstol = abstol, - maxbonddim = maxbonddim, - verbosity = verbosity, - extraIJset = extraIJset, - ) - end - end - - nothing -end - - -function flushpivoterror!(tci::SimpleTCI{ValueType}) where {ValueType} - tci.pivoterrors = Float64[] - nothing -end - -""" -Update pivots at bond `b` of `tci` using the TCI2 algorithm. -Site tensors will be invalidated. -""" -function updatepivots!( - tci::SimpleTCI{ValueType}, - edge::NamedEdge, - f::F; - reltol::Float64 = 1e-14, - abstol::Float64 = 0.0, - maxbonddim::Int = typemax(Int), - verbosity::Int = 0, - extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}} = Dict{ - SubTreeVertex, - Vector{MultiIndex}, - }(), - ) where {F,ValueType} - - N = length(tci.localdims) - - (IJkey, combinedIJset) = generate_pivot_candidates( - DefaultPivotCandidateProper(), - tci, - edge, - extraIJset, - ) - Ikey, Jkey = first(IJkey), last(IJkey) - - t1 = time_ns() - Pi = reshape( - filltensor( - ValueType, - f, - tci.localdims, - combinedIJset, - [Ikey], - [Jkey], - Val(0), - ), - length(combinedIJset[Ikey]), - length(combinedIJset[Jkey]), - ) - t2 = time_ns() - - updatemaxsample!(tci, Pi) - - luci = TCI.MatrixLUCI(Pi, reltol = reltol, abstol = abstol, maxrank = maxbonddim) - # TODO: we will implement luci according to optimal index subsets by following step - # 1. Compute the optimal index subsets (We also need the indices to set new pivots) - # 2. Reshape the Pi matrix by the optimal index subsets - # 3. 
Compute the LUCI by the reshaped Pi matrix - - t3 = time_ns() - if verbosity > 2 - x, y = length(combinedIJset[Ikey]), length(combinedIJset[Jkey]), - println( - " Computing Pi ($x x $y) at bond $b: $(1e-9*(t2-t1)) sec, LU: $(1e-9*(t3-t2)) sec", - ) - end - - tci.IJset[Ikey] = combinedIJset[Ikey][TCI.rowindices(luci)] - tci.IJset[Jkey] = combinedIJset[Jkey][TCI.colindices(luci)] - - updateerrors!(tci, edge, TCI.pivoterrors(luci)) - nothing - end - - -function updatemaxsample!(tci::SimpleTCI{V}, samples::Array{V}) where {V} - tci.maxsamplevalue = TCI.maxabs(tci.maxsamplevalue, samples) -end - -function updateerrors!( - tci::SimpleTCI{T}, - edge::NamedEdge, - errors::AbstractVector{Float64}, -) where {T} - updateedgeerror!(tci, edge, last(errors)) - updatepivoterror!(tci, errors) - nothing -end - -function updateedgeerror!( - tci::SimpleTCI{T}, - edge::NamedEdge, - error::Float64, -) where {T} - tci.bonderrors[edge] = error - nothing -end - -function updatepivoterror!(tci::SimpleTCI{T}, errors::AbstractVector{Float64}) where {T} - erroriter = Iterators.map(max, TCI.padzero(tci.pivoterrors), TCI.padzero(errors)) - tci.pivoterrors = - Iterators.take(erroriter, max(length(tci.pivoterrors), length(errors))) |> collect - nothing -end - -function pivoterror(tci::SimpleTCI{T}) where {T} - return maxbonderror(tci) -end - -function maxbonderror(tci::SimpleTCI{T}) where {T} - return maximum(values(tci.bonderrors)) -end - -""" -Return if site tensors are available -""" function pushunique!(collection, item) if !(item in collection) @@ -391,3 +93,4 @@ function pushunique!(collection, items...) pushunique!(collection, item) end end + diff --git a/src/simpletci_optimize.jl b/src/simpletci_optimize.jl new file mode 100644 index 0000000..7c4d77a --- /dev/null +++ b/src/simpletci_optimize.jl @@ -0,0 +1,276 @@ +@doc """ + optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) + + Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. 
+
+    # Arguments
+    - `tci`: The SimpleTCI object to optimize
+    - `f`: The function to interpolate
+
+    # Keywords
+    - `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence
+    - `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead
+    - `maxbonddim::Int = typemax(Int)`: Maximum bond dimension
+    - `maxiter::Int = 20`: Maximum number of iterations
+    - `sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper()`: Strategy for sweeping
+    - `pivotsearch::Symbol = :full`: Strategy for pivot search
+    - `verbosity::Int = 0`: Verbosity level
+    - `loginterval::Int = 10`: Interval for logging
+    - `normalizeerror::Bool = true`: Whether to normalize errors
+    - `ncheckhistory::Int = 3`: Number of history steps to check
+    - `maxnglobalpivot::Int = 5`: Maximum number of global pivots
+    - `nsearchglobalpivot::Int = 5`: Number of global pivots to search
+    - `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search
+    - `strictlynested::Bool = false`: Whether to enforce strict nesting
+    - `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable
+
+    # Returns
+    - `ranks`: Vector of ranks at each iteration
+    - `errors`: Vector of normalized errors at each iteration
+    """
+function optimize!(
+    tci::SimpleTCI{ValueType},
+    f;
+    tolerance::Union{Float64,Nothing} = nothing,
+    pivottolerance::Union{Float64,Nothing} = nothing,
+    maxbonddim::Int = typemax(Int),
+    maxiter::Int = 20,
+    sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(),
+    pivotsearch::Symbol = :full,
+    verbosity::Int = 0,
+    loginterval::Int = 10,
+    normalizeerror::Bool = true,
+    ncheckhistory::Int = 3,
+    maxnglobalpivot::Int = 5,
+    nsearchglobalpivot::Int = 5,
+    tolmarginglobalsearch::Float64 = 10.0,
+    strictlynested::Bool = false,
+    checkbatchevaluatable::Bool = false,
+) where {ValueType}
+    errors = Float64[]
+    ranks = Int[]
+    nglobalpivots = Int[]
+    local tol::Float64
+
+    if checkbatchevaluatable && !(f isa BatchEvaluator)
+        error("Function `f` is not batch evaluatable")
+    end
+
+    if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot
+        error("nsearchglobalpivot < maxnglobalpivot!")
+    end
+
+    # Deprecate the pivottolerance option
+    if !isnothing(pivottolerance)
+        if !isnothing(tolerance) && (tolerance != pivottolerance)
+            throw(
+                ArgumentError(
+                    "Got different values for pivottolerance and tolerance in optimize!(SimpleTCI). Both of these options have the same meaning. Please assign only `tolerance`.",
+                ),
+            )
+        else
+            @warn "The option `pivottolerance` of `optimize!(tci::SimpleTCI, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future."
+            tol = pivottolerance
+        end
+    elseif !isnothing(tolerance)
+        tol = tolerance
+    else # pivottolerance == tolerance == nothing, therefore set tol to default value
+        tol = 1e-8
+    end
+
+    tstart = time_ns()
+
+    if maxbonddim >= typemax(Int) && tol <= 0
+        throw(
+            ArgumentError(
+                "Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!",
+            ),
+        )
+    end
+
+    globalpivots = MultiIndex[]
+    for iter = 1:maxiter
+        errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
+        abstol = tol * errornormalization
+
+        if verbosity > 1
+            println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep")
+            flush(stdout)
+        end
+
+        sweep2site!(
+            tci,
+            f,
+            2;
+            iter1 = 1,
+            abstol = abstol,
+            maxbonddim = maxbonddim,
+            pivotsearch = pivotsearch,
+            verbosity = verbosity,
+            sweepstrategy = sweepstrategy,
+        )
+        if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0
+            abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots]
+            # number of added global pivots whose error still exceeds abstol
+            nrejections = count(abserr .> abstol)
+            if nrejections > 0
+                println(
+                    " Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)",
+                )
+                flush(stdout)
+            end
+        end
+        push!(errors, last(pivoterror(tci)))
+
+        if verbosity > 1
+            println(
+                " Walltime $(1e-9*(time_ns() - tstart)) sec: start searching global pivots",
+            )
+            flush(stdout)
+        end
+    end
+
+    errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
+    return ranks, errors ./ errornormalization
+end
+
+@doc """
+    Perform 2site sweeps on a SimpleTCI.
+    TODO: Implement for tree structure
+
+    """
+function sweep2site!(
+    tci::SimpleTCI{ValueType},
+    f,
+    niter::Int;
+    iter1::Int = 1,
+    abstol::Float64 = 1e-8,
+    maxbonddim::Int = typemax(Int),
+    sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(),
+    pivotsearch::Symbol = :full,
+    verbosity::Int = 0,
+) where {ValueType}
+
+    edge_path = generate_sweep2site_path(sweepstrategy, tci)
+
+    for iter = iter1:iter1+niter-1
+        extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset))
+        if length(tci.IJset_history) > 0
+            extraIJset = tci.IJset_history[end]
+        end
+
+        push!(tci.IJset_history, deepcopy(tci.IJset))
+
+        flushpivoterror!(tci)
+
+        for edge in edge_path
+            updatepivots!(
+                tci,
+                edge,
+                f;
+                abstol = abstol,
+                maxbonddim = maxbonddim,
+                verbosity = verbosity,
+                extraIJset = extraIJset,
+            )
+        end
+    end
+
+    nothing
+end
+
+
+function flushpivoterror!(tci::SimpleTCI{ValueType}) where {ValueType}
+    tci.pivoterrors = Float64[]
+    nothing
+end
+
+"""
+Update pivots at bond `edge` of `tci` using the TCI2 algorithm.
+Site tensors will be invalidated.
+"""
+function updatepivots!(
+    tci::SimpleTCI{ValueType},
+    edge::NamedEdge,
+    f::F;
+    reltol::Float64 = 1e-14,
+    abstol::Float64 = 0.0,
+    maxbonddim::Int = typemax(Int),
+    verbosity::Int = 0,
+    extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}} = Dict{
+        SubTreeVertex,
+        Vector{MultiIndex},
+    }(),
+) where {F,ValueType}
+
+    N = length(tci.localdims)
+
+    (IJkey, combinedIJset) =
+        generate_pivot_candidates(DefaultPivotCandidateProper(), tci, edge, extraIJset)
+    Ikey, Jkey = first(IJkey), last(IJkey)
+
+    t1 = time_ns()
+    Pi = reshape(
+        filltensor(ValueType, f, tci.localdims, combinedIJset, [Ikey], [Jkey], Val(0)),
+        length(combinedIJset[Ikey]),
+        length(combinedIJset[Jkey]),
+    )
+    t2 = time_ns()
+
+    updatemaxsample!(tci, Pi)
+
+    luci = TCI.MatrixLUCI(Pi, reltol = reltol, abstol = abstol, maxrank = maxbonddim)
+    # TODO: we will implement luci according to optimal index subsets by following step
+    # 1. Compute the optimal index subsets (We also need the indices to set new pivots)
+    # 2. Reshape the Pi matrix by the optimal index subsets
+    # 3. Compute the LUCI by the reshaped Pi matrix
+
+    t3 = time_ns()
+    if verbosity > 2
+        x, y = length(combinedIJset[Ikey]), length(combinedIJset[Jkey])
+        println(
+            " Computing Pi ($x x $y) at bond $edge: $(1e-9*(t2-t1)) sec, LU: $(1e-9*(t3-t2)) sec",
+        )
+    end
+
+    tci.IJset[Ikey] = combinedIJset[Ikey][TCI.rowindices(luci)]
+    tci.IJset[Jkey] = combinedIJset[Jkey][TCI.colindices(luci)]
+
+    updateerrors!(tci, edge, TCI.pivoterrors(luci))
+    nothing
+end
+
+
+function updatemaxsample!(tci::SimpleTCI{V}, samples::Array{V}) where {V}
+    tci.maxsamplevalue = TCI.maxabs(tci.maxsamplevalue, samples)
+end
+
+function updateerrors!(
+    tci::SimpleTCI{T},
+    edge::NamedEdge,
+    errors::AbstractVector{Float64},
+) where {T}
+    updateedgeerror!(tci, edge, last(errors))
+    updatepivoterror!(tci, errors)
+    nothing
+end
+
+function updateedgeerror!(tci::SimpleTCI{T}, edge::NamedEdge, error::Float64) where {T}
+    tci.bonderrors[edge] = error
+    nothing
+end
+
+function updatepivoterror!(tci::SimpleTCI{T}, errors::AbstractVector{Float64}) where {T}
+    erroriter = Iterators.map(max, TCI.padzero(tci.pivoterrors), TCI.padzero(errors))
+    tci.pivoterrors =
+        Iterators.take(erroriter, max(length(tci.pivoterrors), length(errors))) |> collect
+    nothing
+end
+
+function pivoterror(tci::SimpleTCI{T}) where {T}
+    return maxbonderror(tci)
+end
+
+function maxbonderror(tci::SimpleTCI{T}) where {T}
+    return maximum(values(tci.bonderrors))
+end
diff --git a/src/simpletci_utils.jl b/src/simpletci_tensors.jl
similarity index 88%
rename from src/simpletci_utils.jl
rename to src/simpletci_tensors.jl
index af26b85..313cbdf 100644
--- a/src/simpletci_utils.jl
+++ b/src/simpletci_tensors.jl
@@ -4,7 +4,8 @@ function fillsitetensors(
     center_vertex::Int = 0,
 ) where {ValueType}
 
-    sitetensors = Vector{Pair{Array{ValueType},Vector{NamedEdge}}}(undef, length(vertices(tci.g)))
+    sitetensors =
+        Vector{Pair{Array{ValueType},Vector{NamedEdge}}}(undef, length(vertices(tci.g)))
 
     if center_vertex ∉ vertices(tci.g)
         center_vertex = first(vertices(tci.g))
@@ -18,17 +19,21 @@
         for child in children
             # adjacent_edges = adjacentedges(tci.g, child)
             parent = state.parents[child]
-            edge = filter(e -> src(e) == parent && dst(e) == child || dst(e) == parent && src(e) == child, edges(tci.g))
+            edge = filter(
+                e ->
+                    src(e) == parent && dst(e) == child ||
+                    dst(e) == parent && src(e) == child,
+                edges(tci.g),
+            )
             edge = isempty(edge) ? nothing : only(edge)
             incomingedges = setdiff(adjacentedges(tci.g, child), Set([edge]))
-            InKeys = !isempty(incomingedges) ? edgeInIJkeys(tci.g, child, incomingedges) : SubTreeVertex[]
+            InKeys =
+                !isempty(incomingedges) ? edgeInIJkeys(tci.g, child, incomingedges) :
+                SubTreeVertex[]
             OutKeys = edge != nothing ? 
edgeInIJkeys(tci.g, child, edge) : SubTreeVertex[] if d != 0 T = sitetensor(tci, child, edge, InKeys => OutKeys, f) - sitetensors[child] = T => vcat( - incomingedges, - [edge], - ) + sitetensors[child] = T => vcat(incomingedges, [edge]) else T = sitetensor(tci, child, edge, InKeys => OutKeys, f, core = true) sitetensors[child] = T => incomingedges @@ -91,11 +96,11 @@ function sitetensor( length(tci.IJset[I1key]) == sum([length(tci.IJset[key]) for key in Outkeys]) || error("Pivot matrix at bond $(site) is not square!") Tmat = transpose(transpose(P) \ transpose(Pi1)) T = reshape( - Tmat, - tci.localdims[site], - [length(tci.IJset[key]) for key in Inkeys]..., - [length(tci.IJset[key]) for key in Outkeys]..., - ) + Tmat, + tci.localdims[site], + [length(tci.IJset[key]) for key in Inkeys]..., + [length(tci.IJset[key]) for key in Outkeys]..., + ) return T end @@ -186,11 +191,7 @@ function _call( return result end -function edgeInIJkeys( - g::NamedGraph, - v::Int, - combinededges -) +function edgeInIJkeys(g::NamedGraph, v::Int, combinededges) if combinededges isa NamedEdge combinededges = [combinededges] end @@ -204,4 +205,4 @@ function edgeInIJkeys( end end return keys -end \ No newline at end of file +end diff --git a/src/sweep2sitepathproper.jl b/src/sweep2sitepathproper.jl index 2351cbe..7b0e930 100644 --- a/src/sweep2sitepathproper.jl +++ b/src/sweep2sitepathproper.jl @@ -1,18 +1,48 @@ """ Abstract type for pivot candidate generation strategies """ -abstract type Sweep2sitePathProper end +abstract type AbstractSweep2sitePathProper end """ -Default strategy that uses kronecker product and union with extra indices +Default strategy """ -struct DefaultSweep2sitePathProper <: Sweep2sitePathProper end +struct DefaultSweep2sitePathProper <: AbstractSweep2sitePathProper end """ -Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors +Random strategy +""" +struct RandomSweep2sitePathProper <: AbstractSweep2sitePathProper end + +""" +LocalAdjacent strategy +""" +struct LocalAdjacentSweep2sitePathProper <: AbstractSweep2sitePathProper end + +""" +Default strategy that return the sequence path defined by the edges(g) """ function generate_sweep2site_path( ::DefaultSweep2sitePathProper, + tci::SimpleTCI{ValueType}, +) where {ValueType} + return collect(edges(tci.g)) +end + +""" +Random strategy that returns a random sequence of edges +""" +function generate_sweep2site_path( + ::RandomSweep2sitePathProper, + tci::SimpleTCI{ValueType}, +) where {ValueType} + return shuffle(collect(edges(tci.g))) +end + +""" +LocalAdjacent strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors +""" +function generate_sweep2site_path( + ::LocalAdjacentSweep2sitePathProper, tci::SimpleTCI{ValueType}; origin_edge = undef, ) where {ValueType} @@ -42,10 +72,7 @@ function generate_sweep2site_path( while true candidates = candidateedges(tci.g, center_edge) - candidates = filter( - e -> flags[e] == 0, - candidates - ) + candidates = filter(e -> flags[e] == 0, candidates) # If candidates is empty, exit while loop if isempty(candidates) @@ -74,4 +101,4 @@ function generate_sweep2site_path( end return edge_path -end \ No newline at end of file +end diff --git a/src/tree_utils.jl b/src/treegraph_utils.jl similarity index 64% rename from src/tree_utils.jl rename to src/treegraph_utils.jl index bb688f1..a01ace9 100644 --- a/src/tree_utils.jl +++ b/src/treegraph_utils.jl @@ -35,8 +35,8 @@ 
end function adjacentedges( g::NamedGraph, vertex::Int; - combinededges::Union{NamedEdge, Vector{NamedEdge}} = Vector{NamedEdge}() -) ::Vector{NamedEdge} + combinededges::Union{NamedEdge,Vector{NamedEdge}} = Vector{NamedEdge}(), +)::Vector{NamedEdge} if combinededges isa NamedEdge combinededges = [combinededges] end @@ -50,38 +50,32 @@ function adjacentedges( return adjedges end -function candidateedges( - g::NamedGraph, - edge::NamedEdge, -)::Vector{NamedEdge} +function candidateedges(g::NamedGraph, edge::NamedEdge)::Vector{NamedEdge} p, q = separatevertices(g, edge) - candidates = adjacentedges(g, p; combinededges=edge) ∪ adjacentedges(g, q; combinededges=edge) + candidates = + adjacentedges(g, p; combinededges = edge) ∪ + adjacentedges(g, q; combinededges = edge) return candidates end -function distanceedges( - g::NamedGraph, - edge::NamedEdge, -)::Dict{NamedEdge,Int} +function distanceedges(g::NamedGraph, edge::NamedEdge)::Dict{NamedEdge,Int} p, q = separatevertices(g, edge) distances = Dict{NamedEdge,Int}() - distances[edge] = 0 - distances = distanceBFSedge(g, edge, distances) - return distances -end -function distanceBFSedge( - g::NamedGraph, - edge::NamedEdge, - distances::Dict{NamedEdge,Int}, -)::Dict{NamedEdge,Int} - - candidates = candidateedges(g, edge) - candidates = filter(cand -> cand ∉ keys(distances), candidates) - for cand in candidates - distances[cand] = distances[edge] + 1 - distances = - merge!(distances, distanceBFSedge(g, cand, distances)) + function compute_distances(root, opposite, state) + for subvertex in filter(v -> v != root, subtreevertices(g, opposite => root)) + parent = state.parents[subvertex] + e = NamedEdge(parent, subvertex) + distances[e ∈ edges(g) ? e : reverse(e)] = state.dists[subvertex] + end end + + state_p = namedgraph_dijkstra_shortest_paths(g, p) + compute_distances(p, q, state_p) + + state_q = namedgraph_dijkstra_shortest_paths(g, q) + compute_distances(q, p, state_q) + + distances[edge] = 0 return distances end diff --git a/src/treetensornetwork.jl b/src/treetensornetwork.jl index 1e189b4..1109d46 100644 --- a/src/treetensornetwork.jl +++ b/src/treetensornetwork.jl @@ -2,7 +2,8 @@ mutable struct TreeTensorNetwork{ValueType} tensornetwork::TensorNetwork function TreeTensorNetwork( - g::NamedGraph, sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}, + g::NamedGraph, + sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}, ) where {ValueType} !Graphs.is_cyclic(g) || error("TreeTensorNetwork is not supported for loopy tensor network.") @@ -11,8 +12,9 @@ mutable struct TreeTensorNetwork{ValueType} indexs = vcat( Index(size(T)[1], "s$i"), [ - Index(size(T)[j+1], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges) - ] + Index(size(T)[j+1], "$(src(edges[j]))=>$(dst(edges[j]))") for + j = 1:length(edges) + ], ) t = IndexedArray(T, indexs) push!(ttntensors, t) @@ -25,7 +27,7 @@ end function crossinterpolate( ::Type{ValueType}, f, - localdims::Union{Vector{Int}, NTuple{N,Int}}, + localdims::Union{Vector{Int},NTuple{N,Int}}, g::NamedGraph, initialpivots::Vector{MultiIndex} = [ones(Int, length(localdims))]; kwargs..., @@ -39,7 +41,7 @@ end function evaluate( ttn::TreeTensorNetwork{ValueType}, indexset::Union{AbstractVector{Int},NTuple{N,Int}}, - ) where {N, ValueType} +) where {N,ValueType} tn = deepcopy(ttn.tensornetwork) if length(indexset) != length(vertices(tn.data_graph)) throw( @@ -51,8 +53,8 @@ function evaluate( for i = 1:length(vertices(tn.data_graph)) t = tn[i] site = IndexedArray( - [j == indexset[i] ? 
1.0 : 0.0 for j in 1:t.indices[1].dim], - [t.indices[1]] + [j == indexset[i] ? 1.0 : 0.0 for j = 1:t.indices[1].dim], + [t.indices[1]], ) tn[i] = contract(t, site) end @@ -62,4 +64,3 @@ end function (ttn::TreeTensorNetwork{V})(indexset) where {V} return evaluate(ttn, indexset) end - diff --git a/test/runtests.jl b/test/runtests.jl index 149b069..236fbbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,6 @@ using Test # end @testset verbose = true "Actual tests" begin - include("simpletci_utils_test.jl") + include("simpletci_test.jl") end end diff --git a/test/simpletci_utils_test.jl b/test/simpletci_test.jl similarity index 54% rename from test/simpletci_utils_test.jl rename to test/simpletci_test.jl index 8d2a467..a073ff1 100644 --- a/test/simpletci_utils_test.jl +++ b/test/simpletci_test.jl @@ -1,6 +1,6 @@ using Test using TreeTCI -import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge +import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge @testset "simpletci.jl" begin # make graph @@ -12,8 +12,7 @@ import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge add_edge!(g, 5, 6) add_edge!(g, 5, 7) - - @testset "SubTreeVertex" begin + @testset "TreeGraphUtils" begin e = NamedEdge(2 => 4) v1, v2 = TreeTCI.separatevertices(g, e) @test v1 == 2 @@ -30,13 +29,11 @@ import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge @test last(subregions) == [4, 5, 6, 7] - @test Set(TreeTCI.adjacentedges(g, 4)) == Set( - [NamedEdge(2 => 4), NamedEdge(4 => 5)] - ) + @test Set(TreeTCI.adjacentedges(g, 4)) == + Set([NamedEdge(2 => 4), NamedEdge(4 => 5)]) - @test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) == Set( - [NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)] - ) + @test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) == + Set([NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)]) @test TreeTCI.distanceedges(g, NamedEdge(2 => 4)) == Dict( NamedEdge(2 => 4) => 0, @@ -49,4 +46,16 @@ import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge end + @testset "SimpleTCI" begin + localdims = fill(2, length(vertices(g))) + f(v) = 1 / (1 + v' * v) + + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) + @test ttn([1, 1, 1, 1, 1, 1, 1]) ≈ f([1, 1, 1, 1, 1, 1, 1]) + @test ttn([1, 1, 1, 1, 2, 2, 2]) ≈ f([1, 1, 1, 1, 2, 2, 2]) + @test ttn([2, 2, 2, 2, 1, 1, 1]) ≈ f([2, 2, 2, 2, 1, 1, 1]) + @test ttn([1, 2, 1, 2, 1, 2, 1]) ≈ f([1, 2, 1, 2, 1, 2, 1]) + @test ttn([2, 1, 2, 1, 2, 1, 2]) ≈ f([2, 1, 2, 1, 2, 1, 2]) + @test ttn([2, 2, 2, 2, 2, 2, 2]) ≈ f([2, 2, 2, 2, 2, 2, 2]) + end end
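
Usage sketch (not part of the patch itself): the new test above only exercises the default sweep path. The snippet below shows how the sweep strategies added in src/sweep2sitepathproper.jl could be selected. It assumes that crossinterpolate forwards its keyword arguments (tolerance, sweepstrategy) through kwargs... to optimize!; that forwarding is not visible in the hunks shown here, so treat the keyword passing as an assumption rather than confirmed API.

using TreeTCI
import NamedGraphs: NamedGraph, add_edge!

# Same 7-vertex tree and target function as in test/simpletci_test.jl.
g = NamedGraph(7)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 2, 4)
add_edge!(g, 4, 5)
add_edge!(g, 5, 6)
add_edge!(g, 5, 7)
localdims = fill(2, 7)
f(v) = 1 / (1 + v' * v)

# Assumed keyword forwarding: `sweepstrategy` is an optimize! keyword of type
# AbstractSweep2sitePathProper; RandomSweep2sitePathProper() shuffles the edge order.
ttn, ranks, errors = TreeTCI.crossinterpolate(
    Float64,
    f,
    localdims,
    g;
    tolerance = 1e-8,
    sweepstrategy = TreeTCI.RandomSweep2sitePathProper(),
)
@show ttn([1, 2, 1, 2, 1, 2, 1]) ≈ f([1, 2, 1, 2, 1, 2, 1])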