diff --git a/src/tensorci2.jl b/src/tensorci2.jl index 1c4359b..690a7d5 100644 --- a/src/tensorci2.jl +++ b/src/tensorci2.jl @@ -53,6 +53,25 @@ function TensorCI2{ValueType}( return tci end +""" + Initialize a TCI2 object with local pivot lists. +""" +function TensorCI2{ValueType}( + func::F, + localdims::Union{Vector{Int},NTuple{N,Int}}, + Iset::Vector{Vector{MultiIndex}}, + Jset::Vector{Vector{MultiIndex}} +) where {F,ValueType,N} + tci = TensorCI2{ValueType}(localdims) + tci.Iset = Iset + tci.Jset = Jset + pivots = reconstractglobalpivotsfromijset(localdims, tci.Iset, tci.Jset) + tci.maxsamplevalue = maximum(abs, (func(bit) for bit in pivots)) + abs(tci.maxsamplevalue) > 0.0 || error("maxsamplevalue is zero!") + invalidatesitetensors!(tci) + return tci +end + @doc raw""" function printnestinginfo(tci::TensorCI2{T}) where {T} @@ -150,6 +169,24 @@ function updateerrors!( nothing end +function reconstractglobalpivotsfromijset( + localdims::Union{Vector{Int},NTuple{N,Int}}, + Isets::Vector{Vector{MultiIndex}}, + Jsets::Vector{Vector{MultiIndex}} +) where {N} + pivots = [] + l = length(Isets) + for i in 1:l + for Iset in Isets[i] + for Jset in Jsets[i] + for j in 1:localdims[i] + pushunique!(pivots, vcat(Iset, [j], Jset)) + end + end + end + end + return pivots +end """ Add global pivots to index sets diff --git a/test/test_tensorci2.jl b/test/test_tensorci2.jl index efb6e1c..090defc 100644 --- a/test/test_tensorci2.jl +++ b/test/test_tensorci2.jl @@ -393,6 +393,22 @@ import QuanticsGrids as QD end + @testset "initialize_with_local_pivots_list" begin + Random.seed!(1234) + + N = 10 + M = rand(Float64, N, N) + f(v) = M[v[1], v[2]] # 2D function + localdims = fill(N, 2) + mbd = 5 + + tci, ranks, errors = TCI.crossinterpolate2(Float64, f, localdims; maxbonddim=mbd) + tci2 = TCI.TensorCI2{Float64}(f, localdims, tci.Iset, tci.Jset) + @test tci2.maxsamplevalue == tci.maxsamplevalue + @test tci2.Iset == tci.Iset + @test tci2.Jset == tci.Jset + end + @testset "crossinterpolate2_ttcache" begin ValueType = Float64