@@ -122,6 +122,7 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
122122 checkbatchevaluatable:: Bool
123123 loginterval:: Int
124124 initialpivots:: Vector{MultiIndex} # Make it to Vector{MMultiIndex}?
125+ recyclepivots:: Bool
125126end
126127
127128function Base. show (io:: IO , obj:: TCI2PatchCreator{T} ) where {T}
@@ -141,6 +142,7 @@ function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} wher
141142 obj. checkbatchevaluatable,
142143 obj. loginterval,
143144 obj. initialpivots,
145+ obj. recyclepivots,
144146 )
145147end
146148
@@ -157,7 +159,8 @@ function TCI2PatchCreator(
157159 ninitialpivot= 5 ,
158160 checkbatchevaluatable= false ,
159161 loginterval= 10 ,
160- initialpivots= Vector{MultiIndex}[],
162+ initialpivots= MultiIndex[],
163+ recyclepivots= false ,
161164):: TCI2PatchCreator{T} where {T}
162165 # t1 = time_ns()
163166 if projector === nothing
@@ -183,6 +186,7 @@ function TCI2PatchCreator(
183186 checkbatchevaluatable,
184187 loginterval,
185188 initialpivots,
189+ recyclepivots,
186190 )
187191end
188192
@@ -206,6 +210,7 @@ function _crossinterpolate2!(
206210 verbosity:: Int = 0 ,
207211 checkbatchevaluatable= false ,
208212 loginterval= 10 ,
213+ recyclepivots= false ,
209214) where {T}
210215 ncheckhistory = 3
211216 ranks, errors = TCI. optimize! (
@@ -231,13 +236,45 @@ function _crossinterpolate2!(
231236 ncheckhistory_ = min (ncheckhistory, length (errors))
232237 maxbonddim_hist = maximum (ranks[(end - ncheckhistory_ + 1 ): end ])
233238
234- return PatchCreatorResult {T,TensorTrain{T,3}} (
235- TensorTrain (tci), TCI. maxbonderror (tci) < tolerance && maxbonddim_hist < maxbonddim
236- )
239+ if recyclepivots
240+ return PatchCreatorResult {T,TensorTrain{T,3}} (
241+ TensorTrain (tci),
242+ TCI. maxbonderror (tci) < tolerance && maxbonddim_hist < maxbonddim,
243+ _globalpivots (tci),
244+ )
245+
246+ else
247+ return PatchCreatorResult {T,TensorTrain{T,3}} (
248+ TensorTrain (tci),
249+ TCI. maxbonderror (tci) < tolerance && maxbonddim_hist < maxbonddim,
250+ )
251+ end
252+ end
253+
254+ # Generating global pivots from local ones
255+ function _globalpivots (
256+ tci:: TCI.TensorCI2{T} ; onlydiagonal= true
257+ ):: Vector{MultiIndex} where {T}
258+ Isets = tci. Iset
259+ Jsets = tci. Jset
260+ L = length (Isets)
261+ p = Set {MultiIndex} ()
262+ # Pivot matrices
263+ for bondindex in 1 : (L - 1 )
264+ if onlydiagonal
265+ for (x, y) in zip (Isets[bondindex + 1 ], Jsets[bondindex])
266+ push! (p, vcat (x, y))
267+ end
268+ else
269+ for x in Isets[bondindex + 1 ], y in Jsets[bondindex]
270+ push! (p, vcat (x, y))
271+ end
272+ end
273+ end
274+ return collect (p)
237275end
238276
239277function createpatch (obj:: TCI2PatchCreator{T} ) where {T}
240- proj = obj. projector
241278 fsubset = _FuncAdapterTCI2Subset (obj. f)
242279
243280 tci = if isapproxttavailable (obj. f)
@@ -253,21 +290,25 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
253290 end
254291 tci
255292 else
256- # Random initial pivots
257293 initialpivots = MultiIndex[]
258- let
259- mask = [! isprojectedat (proj, n) for n in 1 : length (proj)]
260- for idx in obj. initialpivots
261- idx_ = [[i] for i in idx]
262- if idx_ <= proj
263- push! (initialpivots, idx[mask])
264- end
294+ if obj. recyclepivots
295+ # First patching iteration: random pivots
296+ if length (fsubset. localdims) == length (obj. localdims)
297+ initialpivots = union (
298+ obj. initialpivots,
299+ findinitialpivots (fsubset, fsubset. localdims, obj. ninitialpivot),
300+ )
301+ # Next iterations: recycle previously generated pivots
302+ else
303+ initialpivots = copy (obj. initialpivots)
265304 end
305+ else
306+ initialpivots = union (
307+ obj. initialpivots,
308+ findinitialpivots (fsubset, fsubset. localdims, obj. ninitialpivot),
309+ )
266310 end
267- append! (
268- initialpivots,
269- findinitialpivots (fsubset, fsubset. localdims, obj. ninitialpivot),
270- )
311+
271312 if all (fsubset .(initialpivots) .== 0 )
272313 return PatchCreatorResult {T,TensorTrainState{T}} (nothing , true )
273314 end
@@ -282,6 +323,7 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
282323 verbosity= obj. verbosity,
283324 checkbatchevaluatable= obj. checkbatchevaluatable,
284325 loginterval= obj. loginterval,
326+ recyclepivots= obj. recyclepivots,
285327 )
286328end
287329
@@ -296,14 +338,14 @@ function adaptiveinterpolate(
296338end
297339
298340function adaptiveinterpolate (
299- f:: ProjectableEvaluator{T} ,
300- pordering:: PatchOrdering = PatchOrdering (collect (1 : length (f. sitedims)));
341+ f:: ProjectableEvaluator{T} ;
342+ pordering:: PatchOrdering = PatchOrdering (collect (1 : length (f. sitedims))),
301343 verbosity= 0 ,
302344 maxbonddim= typemax (Int),
303345 tolerance= 1e-8 ,
304- initialpivots= Vector{MultiIndex}[], # Make it to Vector{MMultiIndex}?
346+ initialpivots= MultiIndex[], # Make it to Vector{MMultiIndex}?
347+ recyclepivots= false ,
305348):: ProjTTContainer{T} where {T}
306- t1 = time_ns ()
307349 creator = TCI2PatchCreator (
308350 T,
309351 f,
@@ -313,6 +355,7 @@ function adaptiveinterpolate(
313355 verbosity,
314356 ntry= 10 ,
315357 initialpivots= initialpivots,
358+ recyclepivots= recyclepivots,
316359 )
317360 tmp = adaptiveinterpolate (creator, pordering; verbosity)
318361 return reshape (tmp, f. sitedims)
0 commit comments