diff --git a/Project.toml b/Project.toml index b1c4a0f..b23bdcb 100644 --- a/Project.toml +++ b/Project.toml @@ -8,11 +8,13 @@ BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] BitIntegers = "0.3.5" EllipsisNotation = "1" QuadGK = "2.9" +Random = "1.10.0" julia = "1.6" [extras] diff --git a/docs/src/documentation.md b/docs/src/documentation.md index 719b589..2c28b27 100644 --- a/docs/src/documentation.md +++ b/docs/src/documentation.md @@ -41,7 +41,7 @@ Pages = ["tensorci1.jl", "indexset.jl", "sweepstrategies.jl"] ### Tensor cross interpolation 2 (TCI2) ```@autodocs Modules = [TensorCrossInterpolation] -Pages = ["tensorci2.jl"] +Pages = ["tensorci2.jl", "globalpivotfinder.jl"] ``` ### Integration diff --git a/docs/src/index.md b/docs/src/index.md index 7beec0f..26fd390 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -242,7 +242,7 @@ end `CachedFunction{T}` can wrap a function inheriting from `BatchEvaluator{T}`. In such cases, `CachedFunction{T}` caches the results of batch evaluation. -# Batch evaluation + parallelization +## Batch evaluation + parallelization The batch evalution can be combined with parallelization using threads, MPI, etc. The following sample code use `Threads` to parallelize function evaluations. Note that the function evaluation for a single index set must be thread-safe. @@ -350,3 +350,64 @@ end ``` You can simply pass the wrapped function `parf` to `crossinterpolate2`. + +## Global pivot finder +A each TCI2 sweep, we can find the index sets with high interpolation error and add them to the TCI2 object. +By default, we use a greedy search algorithm to find the index sets with high interpolation error. +However, this may not be effective in some cases. +In such cases, you can use a custom global pivot finder, which must inherit from `TCI.AbstractGlobalPivotFinder`. + +Here's an example of a custom global pivot finder that randomly selects pivots: + +```julia +import TensorCrossInterpolation as TCI + +struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder + npivots::Int +end + +function (finder::CustomGlobalPivotFinder)( + tci::TensorCI2{ValueType}, + f, + abstol::Float64; + verbosity::Int=0 +)::Vector{MultiIndex} where {ValueType} + L = length(tci.localdims) + return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots] +end +``` + +You can use this custom finder by passing it to the `optimize!` function: + +```julia +tci, ranks, errors = crossinterpolate2( + Float64, + f, + localdims, + firstpivots; + globalpivotfinder=CustomGlobalPivotFinder(10) # Use custom finder that adds 10 random pivots +) +``` + +The default global pivot finder (`DefaultGlobalPivotFinder`) uses a greedy search algorithm to find index sets with high interpolation error. It has the following parameters: + +- `nsearch`: Number of initial points to search from (default: 5) +- `maxnglobalpivot`: Maximum number of pivots to add in each iteration (default: 5) +- `tolmarginglobalsearch`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor (default: 10.0) + +You can customize these parameters by creating a `DefaultGlobalPivotFinder` instance: + +```julia +finder = TCI.DefaultGlobalPivotFinder( + nsearch=10, # Search from 10 initial points + maxnglobalpivot=3, # Add at most 3 pivots per iteration + tolmarginglobalsearch=5.0 # Search for errors > 5 * tolerance +) +tci, ranks, errors = crossinterpolate2( + Float64, + f, + localdims, + firstpivots; + globalpivotfinder=finder +) +``` diff --git a/src/TensorCrossInterpolation.jl b/src/TensorCrossInterpolation.jl index 4920c0b..cad8cd5 100644 --- a/src/TensorCrossInterpolation.jl +++ b/src/TensorCrossInterpolation.jl @@ -13,6 +13,7 @@ import Base: ==, + # To define iterators and element access for MCI, TCI and TT objects import Base: isempty, iterate, getindex, lastindex, broadcastable import Base: length, size, sum +import Random export crossinterpolate1, crossinterpolate2, optfirstpivot export tensortrain, TensorTrain, sitedims, evaluate @@ -31,9 +32,10 @@ include("abstracttensortrain.jl") include("cachedtensortrain.jl") include("batcheval.jl") include("cachedfunction.jl") +include("tensortrain.jl") include("tensorci1.jl") +include("globalpivotfinder.jl") include("tensorci2.jl") -include("tensortrain.jl") include("conversion.jl") include("integration.jl") include("contraction.jl") diff --git a/src/globalpivotfinder.jl b/src/globalpivotfinder.jl new file mode 100644 index 0000000..76a867d --- /dev/null +++ b/src/globalpivotfinder.jl @@ -0,0 +1,195 @@ +import Random: AbstractRNG, default_rng + +""" + GlobalPivotSearchInput{ValueType} + +Input data structure for global pivot search algorithms. + +# Fields +- `localdims::Vector{Int}`: Dimensions of each tensor index +- `current_tt::TensorTrain{ValueType,3}`: Current tensor train approximation +- `maxsamplevalue::ValueType`: Maximum absolute value of the function +- `Iset::Vector{Vector{MultiIndex}}`: Set of left indices +- `Jset::Vector{Vector{MultiIndex}}`: Set of right indices +""" +struct GlobalPivotSearchInput{ValueType} + localdims::Vector{Int} + current_tt::TensorTrain{ValueType,3} + maxsamplevalue::Float64 + Iset::Vector{Vector{MultiIndex}} + Jset::Vector{Vector{MultiIndex}} + + """ + GlobalPivotSearchInput( + localdims::Vector{Int}, + current_tt::TensorTrain{ValueType,3}, + maxsamplevalue::ValueType, + Iset::Vector{Vector{MultiIndex}}, + Jset::Vector{Vector{MultiIndex}} + ) where {ValueType} + + Construct a GlobalPivotSearchInput with the given fields. + """ + function GlobalPivotSearchInput{ValueType}( + localdims::Vector{Int}, + current_tt::TensorTrain{ValueType,3}, + maxsamplevalue::Float64, + Iset::Vector{Vector{MultiIndex}}, + Jset::Vector{Vector{MultiIndex}} + ) where {ValueType} + new{ValueType}( + localdims, + current_tt, + maxsamplevalue, + Iset, + Jset + ) + end +end + + +""" + AbstractGlobalPivotFinder + +Abstract type for global pivot finders that search for indices with high interpolation error. +""" +abstract type AbstractGlobalPivotFinder end + +""" + (finder::AbstractGlobalPivotFinder)( + input::GlobalPivotSearchInput{ValueType}, + f, + abstol::Float64; + verbosity::Int=0, + rng::AbstractRNG=Random.default_rng() + )::Vector{MultiIndex} where {ValueType} + +Find global pivots using the given finder algorithm. + +# Arguments +- `input`: Input data for the search algorithm +- `f`: Function to be interpolated +- `abstol`: Absolute tolerance for the interpolation error +- `verbosity`: Verbosity level (default: 0) +- `rng`: Random number generator (default: Random.default_rng()) + +# Returns +- `Vector{MultiIndex}`: Set of indices with high interpolation error +""" +function (finder::AbstractGlobalPivotFinder)( + input::GlobalPivotSearchInput{ValueType}, + f, + abstol::Float64; + verbosity::Int=0, + rng::AbstractRNG=Random.default_rng() +)::Vector{MultiIndex} where {ValueType} + error("find_global_pivots not implemented for $(typeof(finder))") +end + +""" + DefaultGlobalPivotFinder + +Default implementation of global pivot finder that uses random search. + +# Fields +- `nsearch::Int`: Number of initial points to search from +- `maxnglobalpivot::Int`: Maximum number of pivots to add in each iteration +- `tolmarginglobalsearch::Float64`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor +""" +struct DefaultGlobalPivotFinder <: AbstractGlobalPivotFinder + nsearch::Int + maxnglobalpivot::Int + tolmarginglobalsearch::Float64 +end + +""" + DefaultGlobalPivotFinder(; + nsearch::Int=5, + maxnglobalpivot::Int=5, + tolmarginglobalsearch::Float64=10.0 + ) + +Construct a DefaultGlobalPivotFinder with the given parameters. +""" +function DefaultGlobalPivotFinder(; + nsearch::Int=5, + maxnglobalpivot::Int=5, + tolmarginglobalsearch::Float64=10.0 +) + return DefaultGlobalPivotFinder(nsearch, maxnglobalpivot, tolmarginglobalsearch) +end + +""" + (finder::DefaultGlobalPivotFinder)( + input::GlobalPivotSearchInput{ValueType}, + f, + abstol::Float64; + verbosity::Int=0, + rng::AbstractRNG=Random.default_rng() + )::Vector{MultiIndex} where {ValueType} + +Find global pivots using random search. + +# Arguments +- `input`: Input data for the search algorithm +- `f`: Function to be interpolated +- `abstol`: Absolute tolerance for the interpolation error +- `verbosity`: Verbosity level (default: 0) +- `rng`: Random number generator (default: Random.default_rng()) + +# Returns +- `Vector{MultiIndex}`: Set of indices with high interpolation error +""" +function (finder::DefaultGlobalPivotFinder)( + input::GlobalPivotSearchInput{ValueType}, + f, + abstol::Float64; + verbosity::Int=0, + rng::AbstractRNG=Random.default_rng() +)::Vector{MultiIndex} where {ValueType} + L = length(input.localdims) + nsearch = finder.nsearch + maxnglobalpivot = finder.maxnglobalpivot + tolmarginglobalsearch = finder.tolmarginglobalsearch + + # Generate random initial points + initial_points = [[rand(rng, 1:input.localdims[p]) for p in 1:L] for _ in 1:nsearch] + + # Find pivots with high interpolation error + found_pivots = MultiIndex[] + for point in initial_points + # Perform local search from each initial point + current_point = copy(point) + best_error = 0.0 + best_point = copy(point) + + # Local search + for p in 1:L + for v in 1:input.localdims[p] + current_point[p] = v + error = abs(f(current_point) - input.current_tt(current_point)) + if error > best_error + best_error = error + best_point = copy(current_point) + end + end + current_point[p] = point[p] # Reset to original point + end + + # Add point if error is above threshold + if best_error > abstol * tolmarginglobalsearch + push!(found_pivots, best_point) + end + end + + # Limit number of pivots + if length(found_pivots) > maxnglobalpivot + found_pivots = found_pivots[1:maxnglobalpivot] + end + + if verbosity > 0 + println("Found $(length(found_pivots)) global pivots") + end + + return found_pivots +end \ No newline at end of file diff --git a/src/tensorci2.jl b/src/tensorci2.jl index 690a7d5..c85db93 100644 --- a/src/tensorci2.jl +++ b/src/tensorci2.jl @@ -1,4 +1,3 @@ - """ mutable struct TensorCI2{ValueType} <: AbstractTensorTrain{ValueType} @@ -613,7 +612,8 @@ function convergencecriterion( nglobalpivots::AbstractVector{Int}, tolerance::Float64, maxbonddim::Int, - ncheckhistory::Int, + ncheckhistory::Int; + checkconvglobalpivot::Bool=true )::Bool if length(errors) < ncheckhistory return false @@ -622,12 +622,27 @@ function convergencecriterion( lastngpivots = last(nglobalpivots, ncheckhistory) return ( all(last(errors, ncheckhistory) .< tolerance) && - all(lastngpivots .== 0) && + (checkconvglobalpivot ? all(lastngpivots .== 0) : true) && minimum(lastranks) == lastranks[end] ) || all(lastranks .>= maxbonddim) end +""" + GlobalPivotSearchInput(tci::TensorCI2{ValueType}) where {ValueType} + +Construct a GlobalPivotSearchInput from a TensorCI2 object. +""" +function GlobalPivotSearchInput(tci::TensorCI2{ValueType}) where {ValueType} + return GlobalPivotSearchInput{ValueType}( + tci.localdims, + TensorTrain(tci), + tci.maxsamplevalue, + tci.Iset, + tci.Jset + ) +end + """ function optimize!( @@ -664,11 +679,16 @@ Arguments: - `loginterval::Int` can be set to `>= 1` to specify how frequently to print convergence information. Default: `10`. - `normalizeerror::Bool` determines whether to scale the error by the maximum absolute value of `f` found during sampling. If set to `false`, the algorithm continues until the *absolute* error is below `tolerance`. If set to `true`, the algorithm uses the absolute error divided by the maximum sample instead. This is helpful if the magnitude of the function is not known in advance. Default: `true`. - `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`. -- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`. -- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`. -- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`. +- `globalpivotfinder::Union{AbstractGlobalPivotFinder, Nothing}` is a global pivot finder to use for searching global pivots. Default: `nothing`. If `nothing`, a default global pivot finder is used. +- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`. The maximum number of global pivots to add in each iteration. - `strictlynested::Bool` determines whether to preserve partial nesting in the TCI algorithm. Default: `false`. - `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`. +- `checkconvglobalpivot::Bool` Check if the global pivot finder is converged. Default: `true`. In the future, this will be set to `false` by default. + +Arguments (deprecated): +- `pivottolerance::Float64` is the tolerance for the pivot search. Deprecated. +- `nsearchglobalpivot::Int` is the number of search points for the global pivot finder. Deprecated. +- `tolmarginglobalsearch::Float64` is the tolerance for the global pivot finder. Deprecated. Notes: - Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable. @@ -690,11 +710,13 @@ function optimize!( loginterval::Int=10, normalizeerror::Bool=true, ncheckhistory::Int=3, + globalpivotfinder::Union{AbstractGlobalPivotFinder, Nothing}=nothing, maxnglobalpivot::Int=5, nsearchglobalpivot::Int=5, tolmarginglobalsearch::Float64=10.0, strictlynested::Bool=false, - checkbatchevaluatable::Bool=false + checkbatchevaluatable::Bool=false, + checkconvglobalpivot::Bool=true ) where {ValueType} errors = Float64[] ranks = Int[] @@ -705,9 +727,6 @@ function optimize!( error("Function `f` is not batch evaluatable") end - #if maxnglobalpivot > 0 && nsearchglobalpivot > 0 - #!strictlynested || error("nglobalpivots > 0 requires strictlynested=false!") - #end if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot error("nsearchglobalpivot < maxnglobalpivot!") end @@ -734,6 +753,17 @@ function optimize!( )) end + # Create the global pivot finder + finder = if isnothing(globalpivotfinder) + DefaultGlobalPivotFinder( + nsearch=nsearchglobalpivot, + maxnglobalpivot=maxnglobalpivot, + tolmarginglobalsearch=tolmarginglobalsearch + ) + else + globalpivotfinder + end + globalpivots = MultiIndex[] for iter in 1:maxiter errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 @@ -771,12 +801,11 @@ function optimize!( end # Find global pivots where the error is too large - # Such gloval pivots are added to the TCI, invalidating site tensors. - globalpivots = searchglobalpivots( - tci, f, tolmarginglobalsearch * abstol, + input = GlobalPivotSearchInput(tci) + globalpivots = finder( + input, f, abstol; verbosity=verbosity, - maxnglobalpivot=maxnglobalpivot, - nsearch=nsearchglobalpivot + rng=Random.default_rng() ) addglobalpivots!(tci, globalpivots) push!(nglobalpivots, length(globalpivots)) @@ -792,7 +821,10 @@ function optimize!( flush(stdout) end if convergencecriterion( - ranks, errors, nglobalpivots, abstol, maxbonddim, ncheckhistory + ranks, errors, + nglobalpivots, + abstol, maxbonddim, ncheckhistory; + checkconvglobalpivot=checkconvglobalpivot ) break end @@ -889,20 +921,7 @@ end f, localdims::Union{Vector{Int},NTuple{N,Int}}, initialpivots::Vector{MultiIndex}=[ones(Int, length(localdims))]; - tolerance::Float64=1e-8, - pivottolerance::Float64=tolerance, - maxbonddim::Int=typemax(Int), - maxiter::Int=200, - sweepstrategy::Symbol=:backandforth, - pivotsearch::Symbol=:full, - verbosity::Int=0, - loginterval::Int=10, - normalizeerror::Bool=true, - ncheckhistory=3, - maxnglobalpivot::Int=5, - nsearchglobalpivot::Int=5, - tolmarginglobalsearch::Float64=10.0, - strictlynested::Bool=false + kwargs... ) where {ValueType,N} Cross interpolate a function ``f(\mathbf{u})`` using the TCI2 algorithm. Here, the domain of ``f`` is ``\mathbf{u} \in [1, \ldots, d_1] \times [1, \ldots, d_2] \times \ldots \times [1, \ldots, d_{\mathscr{L}}]`` and ``d_1 \ldots d_{\mathscr{L}}`` are the local dimensions. @@ -912,27 +931,13 @@ Arguments: - `f` is the function to be interpolated. `f` should have a single parameter, which is a vector of the same length as `localdims`. The return type should be `ValueType`. - `localdims::Union{Vector{Int},NTuple{N,Int}}` is a `Vector` (or `Tuple`) that contains the local dimension of each index of `f`. - `initialpivots::Vector{MultiIndex}` is a vector of pivots to be used for initialization. Default: `[1, 1, ...]`. -- `tolerance::Float64` is a float specifying the target tolerance for the interpolation. Default: `1e-8`. -- `pivottolerance::Float64` is a float that specifies the tolerance for adding new pivots, i.e. the truncation of tensor train bonds. It should be <= tolerance, otherwise convergence may be impossible. Default: `tolerance`. -- `maxbonddim::Int` specifies the maximum bond dimension for the TCI. Default: `typemax(Int)`, i.e. effectively unlimited. -- `maxiter::Int` is the maximum number of iterations (i.e. optimization sweeps) before aborting the TCI construction. Default: `200`. -- `sweepstrategy::Symbol` specifies whether to sweep forward (:forward), backward (:backward), or back and forth (:backandforth) during optimization. Default: `:backandforth`. -- `pivotsearch::Symbol` determins how pivots are searched (`:full` or `:rook`). Default: `:full`. -- `verbosity::Int` can be set to `>= 1` to get convergence information on standard output during optimization. Default: `0`. -- `loginterval::Int` can be set to `>= 1` to specify how frequently to print convergence information. Default: `10`. -- `normalizeerror::Bool` determines whether to scale the error by the maximum absolute value of `f` found during sampling. If set to `false`, the algorithm continues until the *absolute* error is below `tolerance`. If set to `true`, the algorithm uses the absolute error divided by the maximum sample instead. This is helpful if the magnitude of the function is not known in advance. Default: `true`. -- `ncheckhistory::Int` is the number of history points to use for convergence checks. Default: `3`. -- `maxnglobalpivot::Int` can be set to `>= 0`. Default: `5`. -- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `5`. -- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`. -- `strictlynested::Bool=false` determines whether to preserve partial nesting in the TCI algorithm. Default: `true`. -- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`. + +Refer to [`optimize!`](@ref) for other keyword arguments such as `tolerance`, `maxbonddim`, `maxiter`. Notes: - Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable. - By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate. - See also: [`optimize!`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate1`](@ref) """ function crossinterpolate2( @@ -947,7 +952,6 @@ function crossinterpolate2( return tci, ranks, errors end - """ Search global pivots where the interpolation error exceeds `abstol`. """ diff --git a/test/test_tensorci2.jl b/test/test_tensorci2.jl index 090defc..7c9ec81 100644 --- a/test/test_tensorci2.jl +++ b/test/test_tensorci2.jl @@ -2,6 +2,7 @@ using Test import TensorCrossInterpolation as TCI import TensorCrossInterpolation: rank, linkdims, TensorCI2, updatepivots!, addglobalpivots1sitesweep!, MultiIndex, evaluate, crossinterpolate2, pivoterror, tensortrain, optimize! import Random +import Random: AbstractRNG import QuanticsGrids as QD @testset "TensorCI2" begin @@ -99,6 +100,70 @@ import QuanticsGrids as QD end + + + struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder + npivots::Int + end + + function (finder::CustomGlobalPivotFinder)( + input::TCI.GlobalPivotSearchInput{ValueType}, + f, + abstol::Float64; + verbosity::Int=0, + rng::AbstractRNG=Random.default_rng() + )::Vector{MultiIndex} where {ValueType} + L = length(input.localdims) + return [[rand(rng, 1:input.localdims[p]) for p in 1:L] for _ in 1:finder.npivots] + end + + @testset "custom global pivot finder" begin + pivotsearch = :full + strictlynested = false + nsearchglobalpivot = 10 + + # f(x) = exp(-x) + Random.seed!(1240) + R = 8 + abstol = 1e-4 + + grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,)) + + #index_to_x(i) = (i - 1) / 2^R # x ∈ [0, 1) + fx(x) = exp(-x) + f(bitlist::MultiIndex) = fx(QD.quantics_to_origcoord(grid, bitlist)[1]) + + localdims = fill(2, R) + firstpivots = [ones(Int, R), vcat(1, fill(2, R - 1))] + tci, ranks, errors = crossinterpolate2( + Float64, + f, + localdims, + firstpivots; + tolerance=abstol, + maxbonddim=1, + maxiter=2, + loginterval=1, + verbosity=0, + normalizeerror=false, + globalpivotfinder=CustomGlobalPivotFinder(10), + pivotsearch=pivotsearch, + strictlynested=strictlynested + ) + + @test all(TCI.linkdims(tci) .== 1) + + # Conversion to TT + tt = TCI.TensorTrain(tci) + + for x in [0.1, 0.3, 0.6, 0.9] + indexset = QD.origcoord_to_quantics( + grid, (x,) + ) + @test abs(TCI.evaluate(tci, indexset) - f(indexset)) < abstol + @test abs(TCI.evaluate(tt, indexset) - f(indexset)) < abstol + end + end @testset "trivial MPS(exp), small maxbonddim" begin pivotsearch = :full