Skip to content

Commit 32ff255

Browse files
Implement pivots recycling
1 parent 8afcfe0 commit 32ff255

File tree

3 files changed

+179
-30
lines changed

3 files changed

+179
-30
lines changed

src/crossinterpolate.jl

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
122122
checkbatchevaluatable::Bool
123123
loginterval::Int
124124
initialpivots::Vector{MultiIndex} # Make it to Vector{MMultiIndex}?
125+
recyclepivots::Bool
125126
end
126127

127128
function Base.show(io::IO, obj::TCI2PatchCreator{T}) where {T}
@@ -141,6 +142,7 @@ function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} wher
141142
obj.checkbatchevaluatable,
142143
obj.loginterval,
143144
obj.initialpivots,
145+
obj.recyclepivots,
144146
)
145147
end
146148

@@ -157,7 +159,8 @@ function TCI2PatchCreator(
157159
ninitialpivot=5,
158160
checkbatchevaluatable=false,
159161
loginterval=10,
160-
initialpivots=Vector{MultiIndex}[],
162+
initialpivots=MultiIndex[],
163+
recyclepivots=false,
161164
)::TCI2PatchCreator{T} where {T}
162165
#t1 = time_ns()
163166
if projector === nothing
@@ -183,6 +186,7 @@ function TCI2PatchCreator(
183186
checkbatchevaluatable,
184187
loginterval,
185188
initialpivots,
189+
recyclepivots,
186190
)
187191
end
188192

@@ -206,6 +210,7 @@ function _crossinterpolate2!(
206210
verbosity::Int=0,
207211
checkbatchevaluatable=false,
208212
loginterval=10,
213+
recyclepivots=false,
209214
) where {T}
210215
ncheckhistory = 3
211216
ranks, errors = TCI.optimize!(
@@ -231,13 +236,45 @@ function _crossinterpolate2!(
231236
ncheckhistory_ = min(ncheckhistory, length(errors))
232237
maxbonddim_hist = maximum(ranks[(end - ncheckhistory_ + 1):end])
233238

234-
return PatchCreatorResult{T,TensorTrain{T,3}}(
235-
TensorTrain(tci), TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim
236-
)
239+
if recyclepivots
240+
return PatchCreatorResult{T,TensorTrain{T,3}}(
241+
TensorTrain(tci),
242+
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
243+
_globalpivots(tci),
244+
)
245+
246+
else
247+
return PatchCreatorResult{T,TensorTrain{T,3}}(
248+
TensorTrain(tci),
249+
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
250+
)
251+
end
252+
end
253+
254+
# Generating global pivots from local ones
255+
function _globalpivots(
256+
tci::TCI.TensorCI2{T}; onlydiagonal=true
257+
)::Vector{MultiIndex} where {T}
258+
Isets = tci.Iset
259+
Jsets = tci.Jset
260+
L = length(Isets)
261+
p = Set{MultiIndex}()
262+
# Pivot matrices
263+
for bondindex in 1:(L - 1)
264+
if onlydiagonal
265+
for (x, y) in zip(Isets[bondindex + 1], Jsets[bondindex])
266+
push!(p, vcat(x, y))
267+
end
268+
else
269+
for x in Isets[bondindex + 1], y in Jsets[bondindex]
270+
push!(p, vcat(x, y))
271+
end
272+
end
273+
end
274+
return collect(p)
237275
end
238276

239277
function createpatch(obj::TCI2PatchCreator{T}) where {T}
240-
proj = obj.projector
241278
fsubset = _FuncAdapterTCI2Subset(obj.f)
242279

243280
tci = if isapproxttavailable(obj.f)
@@ -253,21 +290,25 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
253290
end
254291
tci
255292
else
256-
# Random initial pivots
257293
initialpivots = MultiIndex[]
258-
let
259-
mask = [!isprojectedat(proj, n) for n in 1:length(proj)]
260-
for idx in obj.initialpivots
261-
idx_ = [[i] for i in idx]
262-
if idx_ <= proj
263-
push!(initialpivots, idx[mask])
264-
end
294+
if obj.recyclepivots
295+
# First patching iteration: random pivots
296+
if length(fsubset.localdims) == length(obj.localdims)
297+
initialpivots = union(
298+
obj.initialpivots,
299+
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
300+
)
301+
# Next iterations: recycle previously generated pivots
302+
else
303+
initialpivots = copy(obj.initialpivots)
265304
end
305+
else
306+
initialpivots = union(
307+
obj.initialpivots,
308+
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
309+
)
266310
end
267-
append!(
268-
initialpivots,
269-
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
270-
)
311+
271312
if all(fsubset.(initialpivots) .== 0)
272313
return PatchCreatorResult{T,TensorTrainState{T}}(nothing, true)
273314
end
@@ -282,6 +323,7 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
282323
verbosity=obj.verbosity,
283324
checkbatchevaluatable=obj.checkbatchevaluatable,
284325
loginterval=obj.loginterval,
326+
recyclepivots=obj.recyclepivots,
285327
)
286328
end
287329

@@ -296,14 +338,14 @@ function adaptiveinterpolate(
296338
end
297339

298340
function adaptiveinterpolate(
299-
f::ProjectableEvaluator{T},
300-
pordering::PatchOrdering=PatchOrdering(collect(1:length(f.sitedims)));
341+
f::ProjectableEvaluator{T};
342+
pordering::PatchOrdering=PatchOrdering(collect(1:length(f.sitedims))),
301343
verbosity=0,
302344
maxbonddim=typemax(Int),
303345
tolerance=1e-8,
304-
initialpivots=Vector{MultiIndex}[], # Make it to Vector{MMultiIndex}?
346+
initialpivots=MultiIndex[], # Make it to Vector{MMultiIndex}?
347+
recyclepivots=false,
305348
)::ProjTTContainer{T} where {T}
306-
t1 = time_ns()
307349
creator = TCI2PatchCreator(
308350
T,
309351
f,
@@ -313,6 +355,7 @@ function adaptiveinterpolate(
313355
verbosity,
314356
ntry=10,
315357
initialpivots=initialpivots,
358+
recyclepivots=recyclepivots,
316359
)
317360
tmp = adaptiveinterpolate(creator, pordering; verbosity)
318361
return reshape(tmp, f.sitedims)

src/patching.jl

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ abstract type AbstractPatchCreator{T,M} end
3838
mutable struct PatchCreatorResult{T,M}
3939
data::Union{M,Nothing}
4040
isconverged::Bool
41+
resultpivots::Vector{MultiIndex}
42+
43+
function PatchCreatorResult{T,M}(
44+
data::Union{M,Nothing}, isconverged::Bool, resultpivots::Vector{MultiIndex}
45+
)::PatchCreatorResult{T,M} where {T,M}
46+
return new{T,M}(data, isconverged, resultpivots)
47+
end
48+
49+
function PatchCreatorResult{T,M}(
50+
data::Union{M,Nothing}, isconverged::Bool
51+
)::PatchCreatorResult{T,M} where {T,M}
52+
return new{T,M}(data, isconverged, MultiIndex[])
53+
end
4154
end
4255

4356
function _reconst_prefix(projector::Projector, pordering::PatchOrdering)
@@ -63,10 +76,19 @@ function __taskfunc(creator::AbstractPatchCreator{T,M}, pordering; verbosity=0)
6376
for ic in 1:creator.localdims[pordering.ordering[length(prefix) + 1]]
6477
prefix_ = vcat(prefix, ic)
6578
projector_ = makeproj(pordering, prefix_, creator.localdims)
66-
#if verbosity > 0
67-
##println("Creating a task for $(prefix_) ...")
68-
#end
69-
push!(newtasks, project(creator, projector_))
79+
80+
# Pivots are shorter, pordering index is in a different position
81+
active_dims_ = findall(x -> x == [0], creator.projector.data)
82+
pos_ = findfirst(x -> x == pordering.ordering[length(prefix) + 1], active_dims_)
83+
pivots_ = [
84+
copy(piv) for piv in filter(piv -> piv[pos_] == ic, patch.resultpivots)
85+
]
86+
87+
if !isempty(pivots_)
88+
deleteat!.(pivots_, pos_)
89+
end
90+
91+
push!(newtasks, project(creator, projector_; pivots=pivots_))
7092
end
7193
return nothing, newtasks
7294
end
@@ -77,14 +99,19 @@ function _zerott(T, prefix, po::PatchOrdering, localdims::Vector{Int})
7799
return TensorTrain([zeros(T, 1, d, 1) for d in localdims_])
78100
end
79101

80-
function project(obj::AbstractPatchCreator{T,M}, projector::Projector) where {T,M}
102+
function project(
103+
obj::AbstractPatchCreator{T,M},
104+
projector::Projector;
105+
pivots::Vector{MultiIndex}=MultiIndex[],
106+
) where {T,M}
81107
projector <= obj.projector || error(
82108
"Projector $projector is not a subset of the original projector $(obj.f.projector)",
83109
)
84110

85111
obj_copy = TCI2PatchCreator{T}(obj) # shallow copy
86112
obj_copy.projector = deepcopy(projector)
87113
obj_copy.f = project(obj_copy.f, projector)
114+
obj_copy.initialpivots = deepcopy(pivots)
88115
return obj_copy
89116
end
90117

test/crossinterpolate_tests.jl

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ using Random
8989
ntry=10,
9090
)
9191

92-
obj = TCIA.adaptiveinterpolate(creator, pordering; verbosity=2)
92+
obj = TCIA.adaptiveinterpolate(creator, pordering; verbosity=0)
9393

9494
points = [(rand() * 10 - 5, rand() * 10 - 5) for i in 1:100]
9595

@@ -140,7 +140,7 @@ using Random
140140

141141
ptt = TCIA.project(TCIA.ProjTensorTrain(tt), p)
142142

143-
obj = TCIA.adaptiveinterpolate(ptt; verbosity=1, maxbonddim=5)
143+
obj = TCIA.adaptiveinterpolate(ptt; verbosity=0, maxbonddim=5)
144144

145145
@test vec(TCIA.fulltensor(obj)) vec(TCIA.fulltensor(ptt))
146146
end
@@ -175,9 +175,9 @@ using Random
175175
verbosity=0,
176176
ntry=10,
177177
)
178-
obj = TCIA.adaptiveinterpolate(creator, pordering; verbosity=2)
178+
obj = TCIA.adaptiveinterpolate(creator, pordering; verbosity=0)
179179

180-
points = [(rand() * 2*pi - pi, rand() * 2*pi - pi) for i in 1:1000]
180+
points = [(rand() * 2 * pi - pi, rand() * 2 * pi - pi) for i in 1:1000]
181181

182182
@assert length(obj) > localdims[1]
183183

@@ -186,6 +186,85 @@ using Random
186186
[qf(QG.origcoord_to_quantics(grid, p)) for p in points];
187187
atol=1e-4,
188188
)
189+
end
190+
191+
@testset "recyclepivots" begin
192+
Random.seed!(1234)
193+
194+
function linear_gaussians(
195+
x::Float64, y::Float64; centers::Vector{Vector{Float64}}=[[0.0, 0.0]]
196+
)
197+
N_gauss = length(centers)
198+
input_vec = [x, y]
199+
σ = [2.0^(-j - 1) for j in 1:N_gauss]
200+
return sum(exp(-norm(input_vec - centers[j])^2 / σ[j]^2) for j in 1:N_gauss)
201+
end
202+
203+
R = 30
204+
grid = QG.DiscretizedGrid{2}(R, (-3.0, -3.0), (3.0, 3.0))
205+
localdims = fill(4, R)
206+
sitedims = [[2, 2] for _ in 1:R]
207+
208+
gauss_centers = [[-2.0, 2.0], [2.0, -2.0], [-2.0, -2.0], [2.0, 2.0]]
209+
qf =
210+
x -> linear_gaussians(
211+
QG.quantics_to_origcoord(grid, x)...; centers=gauss_centers
212+
)
213+
214+
pordering = TCIA.PatchOrdering(collect(1:R))
215+
tol = 1e-7
216+
mb = 20
217+
218+
# Provide very good pivots on the maximums of the function
219+
initialpivots = [
220+
TCI.optfirstpivot(qf, localdims, QG.origcoord_to_quantics(grid, Tuple(c))) for
221+
c in gauss_centers
222+
]
223+
224+
creator_recycle = TCIA.TCI2PatchCreator(
225+
Float64,
226+
TCIA.makeprojectable(Float64, qf, localdims),
227+
localdims,
228+
;
229+
maxbonddim=mb,
230+
tolerance=tol,
231+
recyclepivots=true,
232+
initialpivots=initialpivots,
233+
ninitialpivot=0,
234+
)
189235

236+
creator_no_recycle = TCIA.TCI2PatchCreator(
237+
Float64,
238+
TCIA.makeprojectable(Float64, qf, localdims),
239+
localdims,
240+
;
241+
maxbonddim=mb,
242+
tolerance=tol,
243+
recyclepivots=false,
244+
initialpivots=initialpivots,
245+
ninitialpivot=0,
246+
)
247+
248+
# Perform ony one patch iteration
249+
_, recycle_patches = TCIA.__taskfunc(creator_recycle, pordering)
250+
_, no_recycle_patches = TCIA.__taskfunc(creator_no_recycle, pordering)
251+
252+
# Check if the "good" pivots are conserved in each patch
253+
for patch in recycle_patches
254+
prj_val = collect(Iterators.flatten(patch.projector))[1]
255+
patch_pivots = vcat.([prj_val], patch.initialpivots)
256+
257+
@test intersect(patch_pivots, initialpivots) ==
258+
filter(piv -> piv[1] == prj_val, initialpivots)
259+
end
260+
261+
# Check that the information is lost
262+
for patch in no_recycle_patches
263+
prj_val = collect(Iterators.flatten(patch.projector))[1]
264+
patch_pivots = vcat.([prj_val], patch.initialpivots)
265+
266+
@test intersect(patch_pivots, initialpivots) !==
267+
filter(piv -> piv[1] == prj_val, initialpivots)
268+
end
190269
end
191270
end

0 commit comments

Comments
 (0)