diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index 8fc82a2..759d43c 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -18,8 +18,8 @@ import SimpleTensorNetworks: import Random: shuffle include("treegraph_utils.jl") include("simpletci.jl") -include("pivotcandidateproper.jl") -include("sweep2sitepathproper.jl") +include("pivotcandidateproposer.jl") +include("sweep2sitepathproposer.jl") include("simpletci_optimize.jl") include("simpletci_tensors.jl") include("treetensornetwork.jl") diff --git a/src/pivotcandidateproper.jl b/src/pivotcandidateproposer.jl similarity index 90% rename from src/pivotcandidateproper.jl rename to src/pivotcandidateproposer.jl index 7ec8659..0ac9bac 100644 --- a/src/pivotcandidateproper.jl +++ b/src/pivotcandidateproposer.jl @@ -1,18 +1,18 @@ """ Abstract type for pivot candidate generation strategies """ -abstract type AbstractPivotCandidateProper end +abstract type AbstractPivotCandidateProposer end """ Default strategy that uses kronecker product and union with extra indices """ -struct DefaultPivotCandidateProper <: AbstractPivotCandidateProper end +struct DefaultPivotCandidateProposer <: AbstractPivotCandidateProposer end """ Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors """ function generate_pivot_candidates( - ::DefaultPivotCandidateProper, + ::DefaultPivotCandidateProposer, tci::SimpleTCI{ValueType}, edge::NamedEdge, extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}}, @@ -60,6 +60,10 @@ function kronecker( site_index = findfirst(==(site), Outkey) filtered_subregions = filter(x -> x ≠ Set([site]), Outkey) + if site_index === nothing + return MultiIndex[] + end + return MultiIndex[ [is[1:site_index-1]..., j, is[site_index+1:end]...] for is in pivotset, j = 1:localdim diff --git a/src/simpletci.jl b/src/simpletci.jl index a4c41d6..1464827 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -93,4 +93,3 @@ function pushunique!(collection, items...) pushunique!(collection, item) end end - diff --git a/src/simpletci_optimize.jl b/src/simpletci_optimize.jl index 7c4d77a..c8aef8d 100644 --- a/src/simpletci_optimize.jl +++ b/src/simpletci_optimize.jl @@ -9,20 +9,13 @@ # Keywords - `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence - - `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead - `maxbonddim::Int = typemax(Int)`: Maximum bond dimension - `maxiter::Int = 20`: Maximum number of iterations - - `sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper()`: Strategy for sweeping - - `pivotsearch::Symbol = :full`: Strategy for pivot search + - `sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer()`: Strategy for sweeping - `verbosity::Int = 0`: Verbosity level - `loginterval::Int = 10`: Interval for logging - `normalizeerror::Bool = true`: Whether to normalize errors - `ncheckhistory::Int = 3`: Number of history steps to check - - `maxnglobalpivot::Int = 5`: Maximum number of global pivots - - `nsearchglobalpivot::Int = 5`: Number of global pivots to search - - `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search - - `strictlynested::Bool = false`: Whether to enforce strict nesting - - `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable # Returns - `ranks`: Vector of ranks at each iteration @@ -31,56 +24,21 @@ function optimize!( tci::SimpleTCI{ValueType}, f; - tolerance::Union{Float64,Nothing} = nothing, - pivottolerance::Union{Float64,Nothing} = nothing, + tolerance::Float64 = 1e-8, maxbonddim::Int = typemax(Int), maxiter::Int = 20, - sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(), - pivotsearch::Symbol = :full, + sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), verbosity::Int = 0, loginterval::Int = 10, normalizeerror::Bool = true, ncheckhistory::Int = 3, - maxnglobalpivot::Int = 5, - nsearchglobalpivot::Int = 5, - tolmarginglobalsearch::Float64 = 10.0, - strictlynested::Bool = false, - checkbatchevaluatable::Bool = false, ) where {ValueType} errors = Float64[] ranks = Int[] - nglobalpivots = Int[] - local tol::Float64 - - if checkbatchevaluatable && !(f isa BatchEvaluator) - error("Function `f` is not batch evaluatable") - end - - if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot - error("nsearchglobalpivot < maxnglobalpivot!") - end - - # Deprecate the pivottolerance option - if !isnothing(pivottolerance) - if !isnothing(tolerance) && (tolerance != pivottolerance) - throw( - ArgumentError( - "Got different values for pivottolerance and tolerance in optimize!(TCI2). For TCI2, both of these options have the same meaning. Please assign only `tolerance`.", - ), - ) - else - @warn "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future." - tol = pivottolerance - end - elseif !isnothing(tolerance) - tol = tolerance - else # pivottolerance == tolerance == nothing, therefore set tol to default value - tol = 1e-8 - end tstart = time_ns() - if maxbonddim >= typemax(Int) && tol <= 0 + if maxbonddim >= typemax(Int) && tolerance <= 0 throw( ArgumentError( "Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!", @@ -91,7 +49,7 @@ function optimize!( globalpivots = MultiIndex[] for iter = 1:maxiter errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 - abstol = tol * errornormalization + abstol = tolerance * errornormalization if verbosity > 1 println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep") @@ -102,10 +60,8 @@ function optimize!( tci, f, 2; - iter1 = 1, abstol = abstol, maxbonddim = maxbonddim, - pivotsearch = pivotsearch, verbosity = verbosity, sweepstrategy = sweepstrategy, ) @@ -142,17 +98,15 @@ function sweep2site!( tci::SimpleTCI{ValueType}, f, niter::Int; - iter1::Int = 1, abstol::Float64 = 1e-8, maxbonddim::Int = typemax(Int), - sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(), - pivotsearch::Symbol = :full, + sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), verbosity::Int = 0, ) where {ValueType} edge_path = generate_sweep2site_path(sweepstrategy, tci) - for iter = iter1:iter1+niter-1 + for _ = 1:niter extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) if length(tci.IJset_history) > 0 extraIJset = tci.IJset_history[end] @@ -205,7 +159,7 @@ function updatepivots!( N = length(tci.localdims) (IJkey, combinedIJset) = - generate_pivot_candidates(DefaultPivotCandidateProper(), tci, edge, extraIJset) + generate_pivot_candidates(DefaultPivotCandidateProposer(), tci, edge, extraIJset) Ikey, Jkey = first(IJkey), last(IJkey) t1 = time_ns() diff --git a/src/sweep2sitepathproper.jl b/src/sweep2sitepathproposer.jl similarity index 86% rename from src/sweep2sitepathproper.jl rename to src/sweep2sitepathproposer.jl index 7b0e930..d09eded 100644 --- a/src/sweep2sitepathproper.jl +++ b/src/sweep2sitepathproposer.jl @@ -1,28 +1,28 @@ """ Abstract type for pivot candidate generation strategies """ -abstract type AbstractSweep2sitePathProper end +abstract type AbstractSweep2sitePathProposer end """ Default strategy """ -struct DefaultSweep2sitePathProper <: AbstractSweep2sitePathProper end +struct DefaultSweep2sitePathProposer <: AbstractSweep2sitePathProposer end """ Random strategy """ -struct RandomSweep2sitePathProper <: AbstractSweep2sitePathProper end +struct RandomSweep2sitePathProposer <: AbstractSweep2sitePathProposer end """ LocalAdjacent strategy """ -struct LocalAdjacentSweep2sitePathProper <: AbstractSweep2sitePathProper end +struct LocalAdjacentSweep2sitePathProposer <: AbstractSweep2sitePathProposer end """ Default strategy that return the sequence path defined by the edges(g) """ function generate_sweep2site_path( - ::DefaultSweep2sitePathProper, + ::DefaultSweep2sitePathProposer, tci::SimpleTCI{ValueType}, ) where {ValueType} return collect(edges(tci.g)) @@ -32,7 +32,7 @@ end Random strategy that returns a random sequence of edges """ function generate_sweep2site_path( - ::RandomSweep2sitePathProper, + ::RandomSweep2sitePathProposer, tci::SimpleTCI{ValueType}, ) where {ValueType} return shuffle(collect(edges(tci.g))) @@ -42,7 +42,7 @@ end LocalAdjacent strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors """ function generate_sweep2site_path( - ::LocalAdjacentSweep2sitePathProper, + ::LocalAdjacentSweep2sitePathProposer, tci::SimpleTCI{ValueType}; origin_edge = undef, ) where {ValueType} diff --git a/src/treetensornetwork.jl b/src/treetensornetwork.jl index 1109d46..d78a691 100644 --- a/src/treetensornetwork.jl +++ b/src/treetensornetwork.jl @@ -64,3 +64,8 @@ end function (ttn::TreeTensorNetwork{V})(indexset) where {V} return evaluate(ttn, indexset) end + +# Add length method for TreeTensorNetwork +function Base.length(ttn::TreeTensorNetwork) + return length(vertices(ttn.tensornetwork.data_graph)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 236fbbd..000c20b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,17 +1,4 @@ -# using ReTestItems: runtests, @testitem -# using TreeTCI: TreeTCI - -# runtests(TreeTCI) - +using ReTestItems: runtests, @testitem using TreeTCI: TreeTCI -using Test - -@testset verbose = true "TreeTCI tests" begin - # @testset "Code quality (Aqua.jl)" begin - # Aqua.test_all(TreeTCI; unbound_args = false, deps_compat = false) - # end - @testset verbose = true "Actual tests" begin - include("simpletci_test.jl") - end -end +runtests(TreeTCI) diff --git a/test/simpletci_test.jl b/test/simpletci_test.jl deleted file mode 100644 index a073ff1..0000000 --- a/test/simpletci_test.jl +++ /dev/null @@ -1,61 +0,0 @@ -using Test -using TreeTCI -import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge - -@testset "simpletci.jl" begin - # make graph - g = NamedGraph(7) - add_edge!(g, 1, 2) - add_edge!(g, 2, 3) - add_edge!(g, 2, 4) - add_edge!(g, 4, 5) - add_edge!(g, 5, 6) - add_edge!(g, 5, 7) - - @testset "TreeGraphUtils" begin - e = NamedEdge(2 => 4) - v1, v2 = TreeTCI.separatevertices(g, e) - @test v1 == 2 - @test v2 == 4 - - Ivertices = TreeTCI.subtreevertices(g, v2 => v1) # 4 -> 2 - Jvertices = TreeTCI.subtreevertices(g, v1 => v2) # 2 -> 4 - - @test Ivertices == [1, 2, 3] - @test Jvertices == [4, 5, 6, 7] - - subregions = TreeTCI.subregionvertices(g, e) - @test first(subregions) == [1, 2, 3] - @test last(subregions) == [4, 5, 6, 7] - - - @test Set(TreeTCI.adjacentedges(g, 4)) == - Set([NamedEdge(2 => 4), NamedEdge(4 => 5)]) - - @test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) == - Set([NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)]) - - @test TreeTCI.distanceedges(g, NamedEdge(2 => 4)) == Dict( - NamedEdge(2 => 4) => 0, - NamedEdge(1 => 2) => 1, - NamedEdge(2 => 3) => 1, - NamedEdge(4 => 5) => 1, - NamedEdge(5 => 6) => 2, - NamedEdge(5 => 7) => 2, - ) - - end - - @testset "SimpleTCI" begin - localdims = fill(2, length(vertices(g))) - f(v) = 1 / (1 + v' * v) - - ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) - @test ttn([1, 1, 1, 1, 1, 1, 1]) ≈ f([1, 1, 1, 1, 1, 1, 1]) - @test ttn([1, 1, 1, 1, 2, 2, 2]) ≈ f([1, 1, 1, 1, 2, 2, 2]) - @test ttn([2, 2, 2, 2, 1, 1, 1]) ≈ f([2, 2, 2, 2, 1, 1, 1]) - @test ttn([1, 2, 1, 2, 1, 2, 1]) ≈ f([1, 2, 1, 2, 1, 2, 1]) - @test ttn([2, 1, 2, 1, 2, 1, 2]) ≈ f([2, 1, 2, 1, 2, 1, 2]) - @test ttn([2, 2, 2, 2, 2, 2, 2]) ≈ f([2, 2, 2, 2, 2, 2, 2]) - end -end diff --git a/test/simpletci_tests.jl b/test/simpletci_tests.jl new file mode 100644 index 0000000..05b707a --- /dev/null +++ b/test/simpletci_tests.jl @@ -0,0 +1,25 @@ +@testitem "SimpleTCI" begin + using Test + using TreeTCI + import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge + + # make graph + g = NamedGraph(7) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 2, 4) + add_edge!(g, 4, 5) + add_edge!(g, 5, 6) + add_edge!(g, 5, 7) + + localdims = fill(2, length(vertices(g))) + f(v) = 1 / (1 + v' * v) + + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) + @test ttn([1, 1, 1, 1, 1, 1, 1]) ≈ f([1, 1, 1, 1, 1, 1, 1]) + @test ttn([1, 1, 1, 1, 2, 2, 2]) ≈ f([1, 1, 1, 1, 2, 2, 2]) + @test ttn([2, 2, 2, 2, 1, 1, 1]) ≈ f([2, 2, 2, 2, 1, 1, 1]) + @test ttn([1, 2, 1, 2, 1, 2, 1]) ≈ f([1, 2, 1, 2, 1, 2, 1]) + @test ttn([2, 1, 2, 1, 2, 1, 2]) ≈ f([2, 1, 2, 1, 2, 1, 2]) + @test ttn([2, 2, 2, 2, 2, 2, 2]) ≈ f([2, 2, 2, 2, 2, 2, 2]) +end diff --git a/test/treegraph_utils_tests.jl b/test/treegraph_utils_tests.jl new file mode 100644 index 0000000..16867e8 --- /dev/null +++ b/test/treegraph_utils_tests.jl @@ -0,0 +1,43 @@ +@testitem "TreeGraphUtils" begin + using Test + using TreeTCI + import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge + + # make graph + g = NamedGraph(7) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 2, 4) + add_edge!(g, 4, 5) + add_edge!(g, 5, 6) + add_edge!(g, 5, 7) + + e = NamedEdge(2 => 4) + v1, v2 = TreeTCI.separatevertices(g, e) + @test v1 == 2 + @test v2 == 4 + + Ivertices = TreeTCI.subtreevertices(g, v2 => v1) # 4 -> 2 + Jvertices = TreeTCI.subtreevertices(g, v1 => v2) # 2 -> 4 + + @test Ivertices == [1, 2, 3] + @test Jvertices == [4, 5, 6, 7] + + subregions = TreeTCI.subregionvertices(g, e) + @test first(subregions) == [1, 2, 3] + @test last(subregions) == [4, 5, 6, 7] + + @test Set(TreeTCI.adjacentedges(g, 4)) == Set([NamedEdge(2 => 4), NamedEdge(4 => 5)]) + + @test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) == + Set([NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)]) + + @test TreeTCI.distanceedges(g, NamedEdge(2 => 4)) == Dict( + NamedEdge(2 => 4) => 0, + NamedEdge(1 => 2) => 1, + NamedEdge(2 => 3) => 1, + NamedEdge(4 => 5) => 1, + NamedEdge(5 => 6) => 2, + NamedEdge(5 => 7) => 2, + ) +end