Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
BitIntegers = "0.3.5"
EllipsisNotation = "1"
QuadGK = "2.9"
Random = "1.10.0"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion docs/src/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Pages = ["tensorci1.jl", "indexset.jl", "sweepstrategies.jl"]
### Tensor cross interpolation 2 (TCI2)
```@autodocs
Modules = [TensorCrossInterpolation]
Pages = ["tensorci2.jl"]
Pages = ["tensorci2.jl", "globalpivotfinder.jl"]
```

### Integration
Expand Down
63 changes: 62 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ end

`CachedFunction{T}` can wrap a function inheriting from `BatchEvaluator{T}`. In such cases, `CachedFunction{T}` caches the results of batch evaluation.

# Batch evaluation + parallelization
## Batch evaluation + parallelization
The batch evalution can be combined with parallelization using threads, MPI, etc.
The following sample code use `Threads` to parallelize function evaluations.
Note that the function evaluation for a single index set must be thread-safe.
Expand Down Expand Up @@ -350,3 +350,64 @@ end
```

You can simply pass the wrapped function `parf` to `crossinterpolate2`.

## Global pivot finder
A each TCI2 sweep, we can find the index sets with high interpolation error and add them to the TCI2 object.
By default, we use a greedy search algorithm to find the index sets with high interpolation error.
However, this may not be effective in some cases.
In such cases, you can use a custom global pivot finder, which must inherit from `TCI.AbstractGlobalPivotFinder`.

Here's an example of a custom global pivot finder that randomly selects pivots:

```julia
import TensorCrossInterpolation as TCI

struct CustomGlobalPivotFinder <: TCI.AbstractGlobalPivotFinder
npivots::Int
end

function (finder::CustomGlobalPivotFinder)(
tci::TensorCI2{ValueType},
f,
abstol::Float64;
verbosity::Int=0
)::Vector{MultiIndex} where {ValueType}
L = length(tci.localdims)
return [[rand(1:tci.localdims[p]) for p in 1:L] for _ in 1:finder.npivots]
end
```

You can use this custom finder by passing it to the `optimize!` function:

```julia
tci, ranks, errors = crossinterpolate2(
Float64,
f,
localdims,
firstpivots;
globalpivotfinder=CustomGlobalPivotFinder(10) # Use custom finder that adds 10 random pivots
)
```

The default global pivot finder (`DefaultGlobalPivotFinder`) uses a greedy search algorithm to find index sets with high interpolation error. It has the following parameters:

- `nsearch`: Number of initial points to search from (default: 5)
- `maxnglobalpivot`: Maximum number of pivots to add in each iteration (default: 5)
- `tolmarginglobalsearch`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor (default: 10.0)

You can customize these parameters by creating a `DefaultGlobalPivotFinder` instance:

```julia
finder = TCI.DefaultGlobalPivotFinder(
nsearch=10, # Search from 10 initial points
maxnglobalpivot=3, # Add at most 3 pivots per iteration
tolmarginglobalsearch=5.0 # Search for errors > 5 * tolerance
)
tci, ranks, errors = crossinterpolate2(
Float64,
f,
localdims,
firstpivots;
globalpivotfinder=finder
)
```
4 changes: 3 additions & 1 deletion src/TensorCrossInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Base: ==, +
# To define iterators and element access for MCI, TCI and TT objects
import Base: isempty, iterate, getindex, lastindex, broadcastable
import Base: length, size, sum
import Random

export crossinterpolate1, crossinterpolate2, optfirstpivot
export tensortrain, TensorTrain, sitedims, evaluate
Expand All @@ -31,9 +32,10 @@ include("abstracttensortrain.jl")
include("cachedtensortrain.jl")
include("batcheval.jl")
include("cachedfunction.jl")
include("tensortrain.jl")
include("tensorci1.jl")
include("globalpivotfinder.jl")
include("tensorci2.jl")
include("tensortrain.jl")
include("conversion.jl")
include("integration.jl")
include("contraction.jl")
Expand Down
195 changes: 195 additions & 0 deletions src/globalpivotfinder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import Random: AbstractRNG, default_rng

"""
GlobalPivotSearchInput{ValueType}

Input data structure for global pivot search algorithms.

# Fields
- `localdims::Vector{Int}`: Dimensions of each tensor index
- `current_tt::TensorTrain{ValueType,3}`: Current tensor train approximation
- `maxsamplevalue::ValueType`: Maximum absolute value of the function
- `Iset::Vector{Vector{MultiIndex}}`: Set of left indices
- `Jset::Vector{Vector{MultiIndex}}`: Set of right indices
"""
struct GlobalPivotSearchInput{ValueType}
localdims::Vector{Int}
current_tt::TensorTrain{ValueType,3}
maxsamplevalue::Float64
Iset::Vector{Vector{MultiIndex}}
Jset::Vector{Vector{MultiIndex}}

"""
GlobalPivotSearchInput(
localdims::Vector{Int},
current_tt::TensorTrain{ValueType,3},
maxsamplevalue::ValueType,
Iset::Vector{Vector{MultiIndex}},
Jset::Vector{Vector{MultiIndex}}
) where {ValueType}

Construct a GlobalPivotSearchInput with the given fields.
"""
function GlobalPivotSearchInput{ValueType}(
localdims::Vector{Int},
current_tt::TensorTrain{ValueType,3},
maxsamplevalue::Float64,
Iset::Vector{Vector{MultiIndex}},
Jset::Vector{Vector{MultiIndex}}
) where {ValueType}
new{ValueType}(
localdims,
current_tt,
maxsamplevalue,
Iset,
Jset
)
end
end


"""
AbstractGlobalPivotFinder

Abstract type for global pivot finders that search for indices with high interpolation error.
"""
abstract type AbstractGlobalPivotFinder end

"""
(finder::AbstractGlobalPivotFinder)(
input::GlobalPivotSearchInput{ValueType},
f,
abstol::Float64;
verbosity::Int=0,
rng::AbstractRNG=Random.default_rng()
)::Vector{MultiIndex} where {ValueType}

Find global pivots using the given finder algorithm.

# Arguments
- `input`: Input data for the search algorithm
- `f`: Function to be interpolated
- `abstol`: Absolute tolerance for the interpolation error
- `verbosity`: Verbosity level (default: 0)
- `rng`: Random number generator (default: Random.default_rng())

# Returns
- `Vector{MultiIndex}`: Set of indices with high interpolation error
"""
function (finder::AbstractGlobalPivotFinder)(
input::GlobalPivotSearchInput{ValueType},
f,
abstol::Float64;
verbosity::Int=0,
rng::AbstractRNG=Random.default_rng()
)::Vector{MultiIndex} where {ValueType}
error("find_global_pivots not implemented for $(typeof(finder))")
end

"""
DefaultGlobalPivotFinder

Default implementation of global pivot finder that uses random search.

# Fields
- `nsearch::Int`: Number of initial points to search from
- `maxnglobalpivot::Int`: Maximum number of pivots to add in each iteration
- `tolmarginglobalsearch::Float64`: Search for pivots where the interpolation error is larger than the tolerance multiplied by this factor
"""
struct DefaultGlobalPivotFinder <: AbstractGlobalPivotFinder
nsearch::Int
maxnglobalpivot::Int
tolmarginglobalsearch::Float64
end

"""
DefaultGlobalPivotFinder(;
nsearch::Int=5,
maxnglobalpivot::Int=5,
tolmarginglobalsearch::Float64=10.0
)

Construct a DefaultGlobalPivotFinder with the given parameters.
"""
function DefaultGlobalPivotFinder(;
nsearch::Int=5,
maxnglobalpivot::Int=5,
tolmarginglobalsearch::Float64=10.0
)
return DefaultGlobalPivotFinder(nsearch, maxnglobalpivot, tolmarginglobalsearch)
end

"""
(finder::DefaultGlobalPivotFinder)(
input::GlobalPivotSearchInput{ValueType},
f,
abstol::Float64;
verbosity::Int=0,
rng::AbstractRNG=Random.default_rng()
)::Vector{MultiIndex} where {ValueType}

Find global pivots using random search.

# Arguments
- `input`: Input data for the search algorithm
- `f`: Function to be interpolated
- `abstol`: Absolute tolerance for the interpolation error
- `verbosity`: Verbosity level (default: 0)
- `rng`: Random number generator (default: Random.default_rng())

# Returns
- `Vector{MultiIndex}`: Set of indices with high interpolation error
"""
function (finder::DefaultGlobalPivotFinder)(
input::GlobalPivotSearchInput{ValueType},
f,
abstol::Float64;
verbosity::Int=0,
rng::AbstractRNG=Random.default_rng()
)::Vector{MultiIndex} where {ValueType}
L = length(input.localdims)
nsearch = finder.nsearch
maxnglobalpivot = finder.maxnglobalpivot
tolmarginglobalsearch = finder.tolmarginglobalsearch

# Generate random initial points
initial_points = [[rand(rng, 1:input.localdims[p]) for p in 1:L] for _ in 1:nsearch]

# Find pivots with high interpolation error
found_pivots = MultiIndex[]
for point in initial_points
# Perform local search from each initial point
current_point = copy(point)
best_error = 0.0
best_point = copy(point)

# Local search
for p in 1:L
for v in 1:input.localdims[p]
current_point[p] = v
error = abs(f(current_point) - input.current_tt(current_point))
if error > best_error
best_error = error
best_point = copy(current_point)
end
end
current_point[p] = point[p] # Reset to original point
end

# Add point if error is above threshold
if best_error > abstol * tolmarginglobalsearch
push!(found_pivots, best_point)
end
end

# Limit number of pivots
if length(found_pivots) > maxnglobalpivot
found_pivots = found_pivots[1:maxnglobalpivot]
end

if verbosity > 0
println("Found $(length(found_pivots)) global pivots")
end

return found_pivots
end
Loading
Loading