Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions src/crossinterpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
checkbatchevaluatable::Bool
loginterval::Int
initialpivots::Vector{MultiIndex} # Make it to Vector{MMultiIndex}?
recyclepivots::Bool
end

function Base.show(io::IO, obj::TCI2PatchCreator{T}) where {T}
Expand All @@ -141,6 +142,7 @@ function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} wher
obj.checkbatchevaluatable,
obj.loginterval,
obj.initialpivots,
obj.recyclepivots,
)
end

Expand All @@ -157,7 +159,8 @@ function TCI2PatchCreator(
ninitialpivot=5,
checkbatchevaluatable=false,
loginterval=10,
initialpivots=Vector{MultiIndex}[],
initialpivots=MultiIndex[],
recyclepivots=false,
)::TCI2PatchCreator{T} where {T}
#t1 = time_ns()
if projector === nothing
Expand All @@ -183,6 +186,7 @@ function TCI2PatchCreator(
checkbatchevaluatable,
loginterval,
initialpivots,
recyclepivots,
)
end

Expand All @@ -206,6 +210,7 @@ function _crossinterpolate2!(
verbosity::Int=0,
checkbatchevaluatable=false,
loginterval=10,
recyclepivots=false,
) where {T}
ncheckhistory = 3
ranks, errors = TCI.optimize!(
Expand All @@ -231,13 +236,45 @@ function _crossinterpolate2!(
ncheckhistory_ = min(ncheckhistory, length(errors))
maxbonddim_hist = maximum(ranks[(end - ncheckhistory_ + 1):end])

return PatchCreatorResult{T,TensorTrain{T,3}}(
TensorTrain(tci), TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim
)
if recyclepivots
return PatchCreatorResult{T,TensorTrain{T,3}}(
TensorTrain(tci),
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
_globalpivots(tci),
)

else
return PatchCreatorResult{T,TensorTrain{T,3}}(
TensorTrain(tci),
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
)
end
end

# Generating global pivots from local ones.
#
# For each bond `b`, a full-length pivot is obtained by concatenating a row
# multi-index from `Iset[b + 1]` with a column multi-index from `Jset[b]`.
# With `onlydiagonal=true` only matching (zipped) row/column pairs are
# combined; otherwise the full cartesian product is used. Duplicates are
# removed via an intermediate `Set`.
function _globalpivots(
    tci::TCI.TensorCI2{T}; onlydiagonal=true
)::Vector{MultiIndex} where {T}
    rowsets = tci.Iset
    colsets = tci.Jset
    nsets = length(rowsets)
    pivots = Set{MultiIndex}()
    for bond in 1:(nsets - 1)
        rows = rowsets[bond + 1]
        cols = colsets[bond]
        if onlydiagonal
            # Pair up row/column indices of the pivot matrix one-to-one.
            for (row, col) in zip(rows, cols)
                push!(pivots, vcat(row, col))
            end
        else
            # Combine every row index with every column index.
            for row in rows, col in cols
                push!(pivots, vcat(row, col))
            end
        end
    end
    return collect(pivots)
end

function createpatch(obj::TCI2PatchCreator{T}) where {T}
proj = obj.projector
fsubset = _FuncAdapterTCI2Subset(obj.f)

tci = if isapproxttavailable(obj.f)
Expand All @@ -253,21 +290,25 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
end
tci
else
# Random initial pivots
initialpivots = MultiIndex[]
let
mask = [!isprojectedat(proj, n) for n in 1:length(proj)]
for idx in obj.initialpivots
idx_ = [[i] for i in idx]
if idx_ <= proj
push!(initialpivots, idx[mask])
end
if obj.recyclepivots
# First patching iteration: random pivots
if length(fsubset.localdims) == length(obj.localdims)
initialpivots = union(
obj.initialpivots,
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
)
# Next iterations: recycle previously generated pivots
else
initialpivots = copy(obj.initialpivots)
end
else
initialpivots = union(
obj.initialpivots,
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
)
end
append!(
initialpivots,
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
)

if all(fsubset.(initialpivots) .== 0)
return PatchCreatorResult{T,TensorTrainState{T}}(nothing, true)
end
Expand All @@ -282,6 +323,7 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
verbosity=obj.verbosity,
checkbatchevaluatable=obj.checkbatchevaluatable,
loginterval=obj.loginterval,
recyclepivots=obj.recyclepivots,
)
end

Expand All @@ -301,9 +343,9 @@ function adaptiveinterpolate(
verbosity=0,
maxbonddim=typemax(Int),
tolerance=1e-8,
initialpivots=Vector{MultiIndex}[], # Make it to Vector{MMultiIndex}?
initialpivots=MultiIndex[], # Make it to Vector{MMultiIndex}?
recyclepivots=false,
)::ProjTTContainer{T} where {T}
t1 = time_ns()
creator = TCI2PatchCreator(
T,
f,
Expand All @@ -313,6 +355,7 @@ function adaptiveinterpolate(
verbosity,
ntry=10,
initialpivots=initialpivots,
recyclepivots=recyclepivots,
)
tmp = adaptiveinterpolate(creator, pordering; verbosity)
return reshape(tmp, f.sitedims)
Expand Down
37 changes: 32 additions & 5 deletions src/patching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ abstract type AbstractPatchCreator{T,M} end
# Result of one patch-creation step: the (possibly absent) fitted data, a
# convergence flag, and any pivots worth recycling in child patches.
mutable struct PatchCreatorResult{T,M}
    data::Union{M,Nothing}            # fitted object, or `nothing` for a trivial patch
    isconverged::Bool                 # true when the local fit met its targets
    resultpivots::Vector{MultiIndex}  # pivots to hand down to sub-patches

    function PatchCreatorResult{T,M}(
        data::Union{M,Nothing}, isconverged::Bool, resultpivots::Vector{MultiIndex}
    ) where {T,M}
        return new{T,M}(data, isconverged, resultpivots)
    end

    # Convenience constructor: no pivots to recycle; delegates to the full one.
    function PatchCreatorResult{T,M}(
        data::Union{M,Nothing}, isconverged::Bool
    ) where {T,M}
        return PatchCreatorResult{T,M}(data, isconverged, MultiIndex[])
    end
end

function _reconst_prefix(projector::Projector, pordering::PatchOrdering)
Expand All @@ -63,10 +76,19 @@ function __taskfunc(creator::AbstractPatchCreator{T,M}, pordering; verbosity=0)
for ic in 1:creator.localdims[pordering.ordering[length(prefix) + 1]]
prefix_ = vcat(prefix, ic)
projector_ = makeproj(pordering, prefix_, creator.localdims)
#if verbosity > 0
##println("Creating a task for $(prefix_) ...")
#end
push!(newtasks, project(creator, projector_))

# Pivots are shorter, pordering index is in a different position
active_dims_ = findall(x -> x == [0], creator.projector.data)
pos_ = findfirst(x -> x == pordering.ordering[length(prefix) + 1], active_dims_)
pivots_ = [
copy(piv) for piv in filter(piv -> piv[pos_] == ic, patch.resultpivots)
]

if !isempty(pivots_)
deleteat!.(pivots_, pos_)
end

push!(newtasks, project(creator, projector_; pivots=pivots_))
end
return nothing, newtasks
end
Expand All @@ -77,14 +99,19 @@ function _zerott(T, prefix, po::PatchOrdering, localdims::Vector{Int})
return TensorTrain([zeros(T, 1, d, 1) for d in localdims_])
end

function project(obj::AbstractPatchCreator{T,M}, projector::Projector) where {T,M}
"""
    project(obj::AbstractPatchCreator{T,M}, projector::Projector;
            pivots::Vector{MultiIndex}=MultiIndex[]) where {T,M}

Return a shallow copy of `obj` restricted to `projector`, with `pivots`
installed as the initial pivots for the subsequent interpolation of the
sub-patch. `projector` must be contained in (`<=`) `obj.projector`;
otherwise an error is thrown.
"""
function project(
    obj::AbstractPatchCreator{T,M},
    projector::Projector;
    pivots::Vector{MultiIndex}=MultiIndex[],
) where {T,M}
    # Bug fix: the message previously printed `obj.f.projector`, but the
    # containment check is against `obj.projector` — report what was compared.
    projector <= obj.projector || error(
        "Projector $projector is not a subset of the original projector $(obj.projector)",
    )

    # NOTE(review): hard-codes TCI2PatchCreator although `obj` is typed as
    # AbstractPatchCreator — confirm no other creator subtype reaches here.
    obj_copy = TCI2PatchCreator{T}(obj) # shallow copy
    obj_copy.projector = deepcopy(projector)
    obj_copy.f = project(obj_copy.f, projector)
    # Own the pivots so later mutation by the caller cannot leak in.
    obj_copy.initialpivots = deepcopy(pivots)
    return obj_copy
end

Expand Down
29 changes: 16 additions & 13 deletions src/projectable_evaluator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,28 +248,31 @@ function batchevaluateprj(
# Some of indices might be projected
NL = length(leftmmultiidxset[1])
NR = length(rightmmultiidxset[1])
L = length(obj)

NL + NR + M == length(obj) ||
error("Length mismatch NL: $NL, NR: $NR, M: $M, L: $(length(obj))")
NL + NR + M == L || error("Length mismatch NL: $NL, NR: $NR, M: $M, L: $(L)")

L = length(obj)
returnshape = projectedshape(obj.projector, NL + 1, L - NR)
result::Array{T,M + 2} = zeros(
T,
length(leftmmultiidxset),
prod.(obj.sitedims[(1 + NL):(L - NR)])...,
length(rightmmultiidxset),
T, length(leftmmultiidxset), returnshape..., length(rightmmultiidxset)
)
result[lmask, .., rmask] .= begin

projmask = map(
p -> p == 0 ? Colon() : p,
Iterators.flatten(obj.projector[n] for n in (1 + NL):(L - NR)),
)
slice = map(
p -> p == 0 ? Colon() : 1,
Iterators.flatten(obj.projector[n] for n in (1 + NL):(L - NR)),
)

result[lmask, slice..., rmask] .= begin
result_lrmask_multii = reshape(
result_lrmask,
size(result_lrmask)[1],
collect(Iterators.flatten(obj.sitedims[(1 + NL):(L - NR)]))...,
size(result_lrmask)[end],
)
projmask = map(
p -> p == 0 ? Colon() : p,
Iterators.flatten(obj.projector[n] for n in (1 + NL):(length(obj) - NR)),
)
) # Gianluca - this reshape might not be needed; I leave it in for safety
result_lrmask_multii[:, projmask..., :]
end
return result
Expand Down
Loading
Loading