diff --git a/.gitignore b/.gitignore index 2251642..b02ba6e 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -Manifest.toml \ No newline at end of file +Manifest.toml +samples/ \ No newline at end of file diff --git a/Project.toml b/Project.toml index fe106a9..bccfef1 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,6 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" [compat] DataGraphs = "0.2.5" Graphs = "1.12.0" -JuliaFormatter = "1" NamedGraphs = "0.6.4" Random = "1.10" SimpleTensorNetworks = "0.1.0" diff --git a/src/pivotcandidateproposer.jl b/src/pivotcandidateproposer.jl index 0ac9bac..5a65ca2 100644 --- a/src/pivotcandidateproposer.jl +++ b/src/pivotcandidateproposer.jl @@ -8,6 +8,16 @@ Default strategy that uses kronecker product and union with extra indices """ struct DefaultPivotCandidateProposer <: AbstractPivotCandidateProposer end +""" +Truncated default strategy that uses kronecker product and union with extra indices +""" +struct TruncatedDefaultPivotCandidateProposer <: AbstractPivotCandidateProposer end + +""" +Simple strategy that uses kronecker product and union with extra indices +""" +struct SimplePivotCandidateProposer <: AbstractPivotCandidateProposer end + """ Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors """ @@ -15,34 +25,84 @@ function generate_pivot_candidates( ::DefaultPivotCandidateProposer, tci::SimpleTCI{ValueType}, edge::NamedEdge, - extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}}, ) where {ValueType} vp, vq = separatevertices(tci.g, edge) - Ikey, subIkey = subtreevertices(tci.g, vq => vp), vp - Jkey, subJkey = subtreevertices(tci.g, vp => vq), vq + Ikey = subtreevertices(tci.g, vq => vp) + Jkey = subtreevertices(tci.g, vp => vq) adjacent_edges_vp = adjacentedges(tci.g, vp; combinededges = edge) InIkeys = edgeInIJkeys(tci.g, vp, adjacent_edges_vp) + Ipivots = pivotset(tci.IJset, InIkeys, Ikey, tci.localdims[vp]) + Isite_index = findfirst(==(vp), Ikey) adjacent_edges_vq = adjacentedges(tci.g, vq; combinededges = edge) InJkeys = edgeInIJkeys(tci.g, vq, adjacent_edges_vq) + Jpivots = pivotset(tci.IJset, InJkeys, Jkey, tci.localdims[vq]) + Jsite_index = findfirst(==(vq), Jkey) + + Iset = kronecker(Ipivots, Isite_index, tci.localdims[vp]) + Jset = kronecker(Jpivots, Jsite_index, tci.localdims[vq]) - # Generate base index sets for both sides - Iset = kronecker(tci.IJset, Ikey, InIkeys, vp, tci.localdims[vp]) - Jset = kronecker(tci.IJset, Jkey, InJkeys, vq, tci.localdims[vq]) + extraIJset = if length(tci.IJset_history) > 0 + extraIJset = tci.IJset_history[end] + else + Dict(key => MultiIndex[] for key in keys(tci.IJset)) + end - # Combine with extra indices if available Icombined = union(Iset, extraIJset[Ikey]) Jcombined = union(Jset, extraIJset[Jkey]) - return (Ikey => Jkey), Dict(Ikey => Icombined, Jkey => Jcombined) + return Dict(Ikey => Icombined, Jkey => Jcombined) end -function kronecker( +function generate_pivot_candidates( + ::TruncatedDefaultPivotCandidateProposer, + tci::SimpleTCI{ValueType}, + edge::NamedEdge, +) where {ValueType} + vp, vq = separatevertices(tci.g, edge) + + Ikey = subtreevertices(tci.g, vq => vp) + Jkey = subtreevertices(tci.g, vp => vq) + chis = Dict(Ikey => tci.localdims[vp] * length(tci.IJset[Ikey]), Jkey => tci.localdims[vq] * length(tci.IJset[Jkey])) + + IJcombined = generate_pivot_candidates(DefaultPivotCandidateProposer(), tci, edge) + IJcombined = Dict( + key => sample_ordered_pivots(IJcombined[key], chis[key]) for + key in keys(IJcombined) + ) + return IJcombined +end + +function generate_pivot_candidates( + ::SimplePivotCandidateProposer, + tci::SimpleTCI{ValueType}, + edge::NamedEdge, +) where {ValueType} + vp, vq = separatevertices(tci.g, edge) + + Ikey = subtreevertices(tci.g, vq => vp) + Ichi = tci.localdims[vp] * length(tci.IJset[Ikey]) + Iset = [[rand(1:tci.localdims[i]) for i in Ikey] for _ = 1:Ichi] + + Jkey = subtreevertices(tci.g, vp => vq) + Jchi = tci.localdims[vq] * length(tci.IJset[Jkey]) + Jset = [[rand(1:tci.localdims[j]) for j in Jkey] for _ = 1:Jchi] + extraIJset = if length(tci.IJset_history) > 0 + extraIJset = tci.IJset_history[end] + else + Dict(key => MultiIndex[] for key in keys(tci.IJset)) + end + Icombined = union(Iset, extraIJset[Ikey]) + Jcombined = union(Jset, extraIJset[Jkey]) + return Dict(Ikey => Icombined, Jkey => Jcombined) +end + + +function pivotset( IJset::Dict{SubTreeVertex,Vector{MultiIndex}}, + Inkeys::Vector{SubTreeVertex}, Outkey::SubTreeVertex, # original subregions order - Inkeys::Vector{SubTreeVertex}, # original subregions order - site::Int, # direct connected site localdim::Int, ) pivotset = MultiIndex[] @@ -56,16 +116,27 @@ function kronecker( end push!(pivotset, indexset) end + return pivotset +end - site_index = findfirst(==(site), Outkey) - filtered_subregions = filter(x -> x ≠ Set([site]), Outkey) - - if site_index === nothing - return MultiIndex[] +function sample_ordered_pivots(pivots::Vector{MultiIndex}, maxsize::Int) + n = length(pivots) + @show n, maxsize + if n ≤ maxsize + return pivots end + selected_indices = shuffle(1:n)[1:maxsize] + return pivots[sort(selected_indices)] +end +function kronecker( + pivotset::Vector{MultiIndex}, + site_index::Union{Int,Nothing}, + localdims::Int, +) + isnothing(site_index) && return MultiIndex[] return MultiIndex[ [is[1:site_index-1]..., j, is[site_index+1:end]...] for is in pivotset, - j = 1:localdim + j = 1:localdims ][:] end diff --git a/src/simpletci.jl b/src/simpletci.jl index 1464827..963d1dc 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -1,28 +1,58 @@ MultiIndex = Vector{Int} SubTreeVertex = Vector{Int} +@doc """ + SimpleTCI{ValueType} + +Tree tensor cross interpolation (TCI) for tree tensor networks. + +# Fields +- `IJset::Dict{SubTreeVertex,Vector{MultiIndex}}`: Pivots sets for each subtrees +- `localdims::Vector{Int}`: Local dimensions for each vertex tensor +- `g::NamedGraph`: Tree graph structure +- `bonderrors::Dict{NamedEdge,Float64}`: Error estimate per bond by 2-site sweep +- `pivoterrors::Vector{Float64}`: Error estimate for backtruncation of bonds +- `maxsamplevalue::Float64`: Maximum sample value for error normalization +- `IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}`: History of pivots sets for each sweep + +# Example +```julia +# Create a simple tree graph +g = NamedGraph([1, 2, 3]) +add_edge!(g, 1 => 2) +add_edge!(g, 2 => 3) + +# Define local dimensions +localdims = [2, 2, 2] + +# Create a SimpleTCI instance +tci = SimpleTCI{Float64}(localdims, g) + +# Add initial pivots +addglobalpivots!(tci, [[1,1,1], [2,1,1]]) +``` +""" mutable struct SimpleTCI{ValueType} IJset::Dict{SubTreeVertex,Vector{MultiIndex}} localdims::Vector{Int} g::NamedGraph - #"Error estimate per bond by 2site sweep." - bonderrors::Dict{NamedEdge,Float64} # key is the bond id - # "Error estimate for backtruncation of bonds." - pivoterrors::Vector{Float64} # key is the bond id - #"Maximum sample for error normalization." + bonderrors::Dict{NamedEdge,Float64} + pivoterrors::Vector{Float64} maxsamplevalue::Float64 IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}} function SimpleTCI{ValueType}(localdims::Vector{Int}, g::NamedGraph) where {ValueType} - length(localdims) > 1 || error("localdims should have at least 2 elements!") n = length(localdims) + n > 1 || error("localdims should have at least 2 elements!") + n == length(vertices(g)) || error( + "The number of vertices in the graph must be equal to the length of localdims.", + ) + !Graphs.is_cyclic(g) || + error("SimpleTCI is not supported for loopy tensor network.") # assign the key for each bond bonderrors = Dict(e => 0.0 for e in edges(g)) - !Graphs.is_cyclic(g) || - error("TreeTensorNetwork is not supported for loopy tensor network.") - new{ValueType}( Dict{SubTreeVertex,Vector{MultiIndex}}(), # IJset localdims, @@ -35,6 +65,10 @@ mutable struct SimpleTCI{ValueType} end end +""" + Initialize a SimpleTCI instance with a function, local dimensions, and graph. + The initial grobal pivots are set to ones(Int, length(localdims)). +""" function SimpleTCI{ValueType}( func::F, localdims::Vector{Int}, @@ -49,14 +83,14 @@ function SimpleTCI{ValueType}( return tci end -@doc """ - Add global pivots to index sets - """ +""" + Add global pivots to IJset. +""" function addglobalpivots!( tci::SimpleTCI{ValueType}, pivots::Vector{MultiIndex}, ) where {ValueType} - if any(length(tci.localdims) .!= length.(pivots)) # AbstructTreeTensorNetworkをから引き継ぎlength(tci)ができると良い + if any(length(tci.localdims) .!= length.(pivots)) throw(DimensionMismatch("Please specify a pivot as one index per leg of the TTN.")) end for pivot in pivots @@ -71,17 +105,16 @@ function addglobalpivots!( if !haskey(tci.IJset, Jset_key) tci.IJset[Jset_key] = Vector{MultiIndex}() end - pushunique!(tci.IJset[Iset_key], [pivot[i] for i in Iset_key]) - pushunique!(tci.IJset[Jset_key], [pivot[j] for j in Jset_key]) + pushunique!(tci.IJset[Iset_key], MultiIndex([pivot[i] for i in Iset_key])) + pushunique!(tci.IJset[Jset_key], MultiIndex([pivot[j] for j in Jset_key])) end end - tci.IJset[[i for i = 1:length(tci.localdims)]] = Int[] + tci.IJset[[i for i = 1:length(tci.localdims)]] = MultiIndex[] nothing end - function pushunique!(collection, item) if !(item in collection) push!(collection, item) diff --git a/src/simpletci_optimize.jl b/src/simpletci_optimize.jl index c8aef8d..04a8f3f 100644 --- a/src/simpletci_optimize.jl +++ b/src/simpletci_optimize.jl @@ -1,26 +1,41 @@ -@doc """ - optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) - - Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. - - # Arguments - - `tci`: The SimpleTCI object to optimize - - `f`: The function to interpolate - - # Keywords - - `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence - - `maxbonddim::Int = typemax(Int)`: Maximum bond dimension - - `maxiter::Int = 20`: Maximum number of iterations - - `sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer()`: Strategy for sweeping - - `verbosity::Int = 0`: Verbosity level - - `loginterval::Int = 10`: Interval for logging - - `normalizeerror::Bool = true`: Whether to normalize errors - - `ncheckhistory::Int = 3`: Number of history steps to check - - # Returns - - `ranks`: Vector of ranks at each iteration - - `errors`: Vector of normalized errors at each iteration - """ +@doc raw""" + optimize!( + tci::SimpleTCI{ValueType}, f; + tolerance::Union{Float64,Nothing} = nothing, + maxbonddim::Int = typemax(Int), + maxiter::Int = 20, + sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), + pivotstrategy::AbstractPivotCandidateProposer = DefaultPivotCandidateProposer(), + verbosity::Int = 0, + loginterval::Int = 10, + normalizeerror::Bool = true, + ncheckhistory::Int = 3, + ) + +Optimize the SimpleTCI instance by iteratively updating pivots. + +# Arguments +- `tci`: The SimpleTCI object to optimize +- `f`: The function to interpolate +- `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence +- `maxbonddim::Int = typemax(Int)`: Maximum bond dimension +- `maxiter::Int = 20`: Maximum number of iterations +- `sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer()`: Strategy for sweeping +- `pivotstrategy::AbstractPivotCandidateProposer = DefaultPivotCandidateProposer()`: Strategy for proposing pivot candidates +- `verbosity::Int = 0`: Verbosity level +- `loginterval::Int = 10`: Interval for logging +- `normalizeerror::Bool = true`: Whether to normalize errors +- `ncheckhistory::Int = 3`: Number of history steps to check + +# Returns +- `ranks`: Vector of ranks at each iteration +- `errors`: Vector of normalized errors at each iteration + +# Note +- The SimpleTCI object will be modified in place. +- Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable. + +""" function optimize!( tci::SimpleTCI{ValueType}, f; @@ -28,6 +43,7 @@ function optimize!( maxbonddim::Int = typemax(Int), maxiter::Int = 20, sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), + pivotstrategy::AbstractPivotCandidateProposer = DefaultPivotCandidateProposer(), verbosity::Int = 0, loginterval::Int = 10, normalizeerror::Bool = true, @@ -64,6 +80,7 @@ function optimize!( maxbonddim = maxbonddim, verbosity = verbosity, sweepstrategy = sweepstrategy, + pivotstrategy = pivotstrategy, ) if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0 abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots] @@ -89,11 +106,9 @@ function optimize!( return ranks, errors ./ errornormalization end -@doc """ +""" Perform 2site sweeps on a SimpleTCI. - !TODO: Implement for Tree structure - - """ +""" function sweep2site!( tci::SimpleTCI{ValueType}, f, @@ -101,6 +116,7 @@ function sweep2site!( abstol::Float64 = 1e-8, maxbonddim::Int = typemax(Int), sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), + pivotstrategy::AbstractPivotCandidateProposer = DefaultPivotCandidateProposer(), verbosity::Int = 0, ) where {ValueType} @@ -108,9 +124,6 @@ function sweep2site!( for _ = 1:niter extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) - if length(tci.IJset_history) > 0 - extraIJset = tci.IJset_history[end] - end push!(tci.IJset_history, deepcopy(tci.IJset)) @@ -123,8 +136,8 @@ function sweep2site!( f; abstol = abstol, maxbonddim = maxbonddim, + pivotstrategy = pivotstrategy, verbosity = verbosity, - extraIJset = extraIJset, ) end end @@ -132,15 +145,8 @@ function sweep2site!( nothing end - -function flushpivoterror!(tci::SimpleTCI{ValueType}) where {ValueType} - tci.pivoterrors = Float64[] - nothing -end - """ -Update pivots at bond `b` of `tci` using the TCI2 algorithm. -Site tensors will be invalidated. + Update pivots at bond of tci object. """ function updatepivots!( tci::SimpleTCI{ValueType}, @@ -149,18 +155,15 @@ function updatepivots!( reltol::Float64 = 1e-14, abstol::Float64 = 0.0, maxbonddim::Int = typemax(Int), + pivotstrategy::AbstractPivotCandidateProposer = DefaultPivotCandidateProposer(), verbosity::Int = 0, - extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}} = Dict{ - SubTreeVertex, - Vector{MultiIndex}, - }(), ) where {F,ValueType} N = length(tci.localdims) - (IJkey, combinedIJset) = - generate_pivot_candidates(DefaultPivotCandidateProposer(), tci, edge, extraIJset) - Ikey, Jkey = first(IJkey), last(IJkey) + combinedIJset = generate_pivot_candidates(pivotstrategy, tci, edge) + keys_array = collect(keys(combinedIJset)) + Ikey, Jkey = first(keys_array), last(keys_array) t1 = time_ns() Pi = reshape( @@ -173,10 +176,6 @@ function updatepivots!( updatemaxsample!(tci, Pi) luci = TCI.MatrixLUCI(Pi, reltol = reltol, abstol = abstol, maxrank = maxbonddim) - # TODO: we will implement luci according to optimal index subsets by following step - # 1. Compute the optimal index subsets (We also need the indices to set new pivots) - # 2. Reshape the Pi matrix by the optimal index subsets - # 3. Compute the LUCI by the reshaped Pi matrix t3 = time_ns() if verbosity > 2 @@ -194,7 +193,6 @@ function updatepivots!( nothing end - function updatemaxsample!(tci::SimpleTCI{V}, samples::Array{V}) where {V} tci.maxsamplevalue = TCI.maxabs(tci.maxsamplevalue, samples) end @@ -209,6 +207,11 @@ function updateerrors!( nothing end +function flushpivoterror!(tci::SimpleTCI{ValueType}) where {ValueType} + tci.pivoterrors = Float64[] + nothing +end + function updateedgeerror!(tci::SimpleTCI{T}, edge::NamedEdge, error::Float64) where {T} tci.bonderrors[edge] = error nothing diff --git a/src/simpletci_tensors.jl b/src/simpletci_tensors.jl index 313cbdf..a5d9ea7 100644 --- a/src/simpletci_tensors.jl +++ b/src/simpletci_tensors.jl @@ -1,3 +1,20 @@ +@doc """ + fillsitetensors( + tci::SimpleTCI{ValueType}, + f; + center_vertex::Int = 0, + ) where {ValueType} + + Fill the site tensors by using a SimpleTCI instance for a tree tensor network. + Center vertex is the vertex of the canonical center of the tree tensor network. + + # Arguments + - `tci::SimpleTCI{ValueType}`: The SimpleTCI instance to fill the site tensors for. + - `f`: The function to use for filling the site tensors. + + # Returns + - `sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}`: The site tensors and the edges connecting them to form TensorNetwork by using the SimpleTCI instance. +""" function fillsitetensors( tci::SimpleTCI{ValueType}, f; @@ -114,6 +131,7 @@ function filltensor( Outkeys::Vector{SubTreeVertex}, ::Val{M}, )::Array{ValueType} where {ValueType,M} + N = length(localdims) nin = sum([length(first(IJset[key])) for key in Inkeys]) nout = sum([length(first(IJset[key])) for key in Outkeys]) @@ -136,6 +154,7 @@ function filltensor( ) end + function _call( ::Type{V}, f, @@ -191,7 +210,32 @@ function _call( return result end -function edgeInIJkeys(g::NamedGraph, v::Int, combinededges) +@doc """ + edgeInIJkeys(g::NamedGraph, v::Int, combinededges) + + Get the pivots keys for the incoming direction at the vertex from connecting edges. + + # Arguments + - `g::NamedGraph`: The graph to get the index sets for. + - `v::Int`: The vertex to get the index sets for. + - `combinededges::Union{NamedEdge,Vector{NamedEdge}}`: The edges to get the index sets for. + + # Returns + - `keys::Vector{SubTreeVertex}`: The pivots keys for the incoming direction at the vertex from connecting edges. + + # Example + ```julia + g = NamedGraph([1, 2, 3, 4]) + add_edge!(g, 1 => 2) + add_edge!(g, 2 => 3) + add_edge!(g, 3 => 4) + @show edgeInIJkeys(g, 2, [2 => 3]) + # [SubTreeVertex([2, 3, 4])] + @show edgeInIJkeys(g, 2, [1 => 2]) + # [SubTreeVertex([1, 2])] + ``` +""" +function edgeInIJkeys(g::NamedGraph, v::Int, combinededges)::Vector{SubTreeVertex} if combinededges isa NamedEdge combinededges = [combinededges] end diff --git a/src/sweep2sitepathproposer.jl b/src/sweep2sitepathproposer.jl index d09eded..ae1139b 100644 --- a/src/sweep2sitepathproposer.jl +++ b/src/sweep2sitepathproposer.jl @@ -48,7 +48,7 @@ function generate_sweep2site_path( ) where {ValueType} edge_path = Vector{NamedEdge}() - n = length(tci.localdims) # TODO: Implement for AbstractTreeTensorNetwork + n = length(vertices(tci.g)) # choose the center bond id. if origin_edge == undef @@ -72,7 +72,8 @@ function generate_sweep2site_path( while true candidates = candidateedges(tci.g, center_edge) - candidates = filter(e -> flags[e] == 0, candidates) + candidates = [e for e in candidates if flags[e] == 0] + # If candidates is empty, exit while loop if isempty(candidates) diff --git a/src/treegraph_utils.jl b/src/treegraph_utils.jl index a01ace9..e0c4a86 100644 --- a/src/treegraph_utils.jl +++ b/src/treegraph_utils.jl @@ -14,7 +14,7 @@ function subtreevertices( if children isa Int children = [children] end - grandchildren = [] + grandchildren = Int[] for child in children candidates = outneighbors(g, child) candidates = [cand for cand in candidates if cand != parent] @@ -52,9 +52,12 @@ end function candidateedges(g::NamedGraph, edge::NamedEdge)::Vector{NamedEdge} p, q = separatevertices(g, edge) - candidates = - adjacentedges(g, p; combinededges = edge) ∪ - adjacentedges(g, q; combinededges = edge) + candidates = unique( + vcat( + adjacentedges(g, p; combinededges = edge), + adjacentedges(g, q; combinededges = edge), + ), + ) return candidates end diff --git a/src/treetensornetwork.jl b/src/treetensornetwork.jl index d78a691..1499d90 100644 --- a/src/treetensornetwork.jl +++ b/src/treetensornetwork.jl @@ -24,6 +24,58 @@ mutable struct TreeTensorNetwork{ValueType} end end +@doc """ + function crossinterpolate( + ::Type{ValueType}, + f, + localdims::Union{Vector{Int},NTuple{N,Int}}, + g::NamedGraph, + initialpivots::Vector{MultiIndex}=[ones(Int, length(localdims))]; + kwargs... + ) where {ValueType,N} + +Cross interpolate a function using the 2-site TCI algorithm. + +# Arguments: +- `ValueType` is the return type of `f`. Automatic inference is too error-prone. +- `localdims::Union{Vector{Int},NTuple{N,Int}}` is a `Vector` (or `Tuple`) that contains the local dimension of each index of `f`. +- `f` is the function to be interpolated. `f` should have a single parameter, which is a vector of the same length as `localdims`. The return type should be `ValueType`. +- `g::NamedGraph` is the graph on which the function is defined. +- `initialpivots::Vector{MultiIndex}` is a vector of pivots to be used for initialization. Default: `[1, 1, ...]`. + +# Keywords +- `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence +- `maxbonddim::Int = typemax(Int)`: Maximum bond dimension +- `maxiter::Int = 20`: Maximum number of iterations +- `sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer()`: Strategy for sweeping +- `pivotstrategy::AbstractPivotCandidateProposer = DefaultPivotCandidateProposer()`: Strategy for proposing pivot candidates +- `verbosity::Int = 0`: Verbosity level +- `loginterval::Int = 10`: Interval for logging +- `normalizeerror::Bool = true`: Whether to normalize errors +- `ncheckhistory::Int = 3`: Number of history steps to check + +# Returns +- `ttn::TreeTensorNetwork`: A `TreeTensorNetwork` object +- `ranks::Vector{Int}`: The ranks of the tensors in the `TreeTensorNetwork` +- `errors::Vector{Float64}`: The errors of the tensors in the `TreeTensorNetwork` + +Notes: +- Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable. +- By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate. + +# Example +```julia +g = NamedGraph([1, 2, 3]) +add_edge!(g, 1 => 2) +add_edge!(g, 2 => 3) +f(x) = x[1] + x[2] * x[3] +localdims = [2, 2, 2] +ttn = crossinterpolate(f, localdims, g) +``` + +See also: [`optimize!`](@ref) +""" + function crossinterpolate( ::Type{ValueType}, f, diff --git a/test/simpletci_tests.jl b/test/simpletci_tests.jl index 05b707a..1cd2016 100644 --- a/test/simpletci_tests.jl +++ b/test/simpletci_tests.jl @@ -14,8 +14,12 @@ localdims = fill(2, length(vertices(g))) f(v) = 1 / (1 + v' * v) - - ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) + kwargs = ( + maxbonddim = 5, + maxiter = 10, + pivotstrategy = TreeTCI.SimplePivotCandidateProposer(), + ) + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...) @test ttn([1, 1, 1, 1, 1, 1, 1]) ≈ f([1, 1, 1, 1, 1, 1, 1]) @test ttn([1, 1, 1, 1, 2, 2, 2]) ≈ f([1, 1, 1, 1, 2, 2, 2]) @test ttn([2, 2, 2, 2, 1, 1, 1]) ≈ f([2, 2, 2, 2, 1, 1, 1])