Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
5f8ab76
Add average and weightedsum functions
Sep 11, 2024
8064e3d
Distributed implementation of site2sweep
Sep 20, 2024
d10e2cc
Synch with main and improve documentation
Sep 20, 2024
bde32e1
Update Project.toml
simone-fdr Sep 20, 2024
77764c0
Implement non-homogeneous distribution
Nov 12, 2024
fc230bf
Add intra-bond parallelism
Dec 21, 2024
a7d0d58
create branch and optional SVD,LU,CI
Jan 22, 2025
35d00b0
Implement RSVD and RRRSVD
Jan 27, 2025
dee8a29
Improve RSVD and R3SVD
Feb 4, 2025
1fe0957
Better r3svd and power method
Feb 5, 2025
c9142a7
Implement distributed zipup
Apr 9, 2025
928fed5
merge with main
Apr 9, 2025
960f20b
minor bug fix
Apr 14, 2025
25c0637
Draft code for FIT and compatibility with patches
May 5, 2025
6d1c0fd
add unpolished version of fit and distributed fit algorithms
Jun 20, 2025
3ffc5b3
pass Aqua and JET tests
Jun 20, 2025
d307440
add test_mpi
Jun 20, 2025
16adc44
add unit tests and documentation
Jun 26, 2025
c03b77e
small fix before merging
Jun 26, 2025
95d5410
solve merge conflict
Jun 26, 2025
4c23bcf
remove old TODOs
Jun 26, 2025
174be98
fixed documentation.md
Jun 26, 2025
76f285e
fit and distrfit improvement. minor function creation
Aug 1, 2025
e9a8831
merge with main
Aug 1, 2025
65c07cd
Fixed documentation
Aug 1, 2025
00ac23a
Fixed type inference error
Aug 1, 2025
3b371d5
Fixed documentation (moved a comment line)
Aug 1, 2025
4e3e26e
better input-output management
Sep 11, 2025
be46504
fixed typo
Sep 11, 2025
2fe893b
branch ready to be pruned
Nov 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
*.jl.mem
/Manifest.toml
/docs/build/
LocalPreferences.toml
.vscode/settings.json
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.9.17"
BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

Expand All @@ -16,6 +17,7 @@ EllipsisNotation = "1"
QuadGK = "2.9"
Random = "1.10.0"
julia = "1.6"
MPI = "0.20.22"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
5 changes: 5 additions & 0 deletions docs/src/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ Modules = [TensorCrossInterpolation]
Pages = ["cachedfunction.jl", "batcheval.jl", "util.jl", "globalsearch.jl"]
```

## Parallel utility
```@autodocs
Modules = [TensorCrossInterpolation]
Pages = ["mpi.jl"]
```
5 changes: 4 additions & 1 deletion src/TensorCrossInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ module TensorCrossInterpolation
using LinearAlgebra
using EllipsisNotation
using BitIntegers
import QuadGK
using MPI
using Base.Threads

import QuadGK
# To add a method for rank(tci)
import LinearAlgebra: rank, diag
import LinearAlgebra as LA
Expand Down Expand Up @@ -40,5 +42,6 @@ include("conversion.jl")
include("integration.jl")
include("contraction.jl")
include("globalsearch.jl")
include("mpi.jl")

end
291 changes: 287 additions & 4 deletions src/abstracttensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ function linkdims(tt::AbstractTensorTrain{V})::Vector{Int} where {V}
return [size(T, 1) for T in tt[2:end]]
end

"""
function linkdims(tt::AbstractTensorTrain{V})::Vector{Int} where {V}

Bond dimensions along the links between ``T`` tensors in the tensor train.

See also: [`rank`](@ref)
"""
function linkdims(tt::Vector{Array{V, N}})::Vector{Int} where {V, N}
return [size(T, 1) for T in tt[2:end]]
end

"""
function linkdim(tt::AbstractTensorTrain{V}, i::Int)::Int where {V}

Expand Down Expand Up @@ -131,6 +142,26 @@ function evaluate(
return only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
end

"""
function evaluate(
tt::TensorTrain{V},
indexset::Union{AbstractVector{LocalIndex}, NTuple{N, LocalIndex}}
)::V where {N, V}

Evaluates the tensor train `tt` at indices given by `indexset` and `jndexset`. This is ment to be used for MPOs.
"""
function evaluate(
tt::AbstractTensorTrain{V},
indexset::Union{AbstractVector{LocalIndex},NTuple{N,LocalIndex}},
jndexset::Union{AbstractVector{LocalIndex},NTuple{N,LocalIndex}}
)::V where {N,V}
if length(indexset) != length(tt)
throw(ArgumentError("To evaluate a tt of length $(length(tt)), you have to provide $(length(tt)) indices, but there were $(length(indexset))."))
end
return only(prod(T[:, i, j, :] for (T, i, j) in zip(tt, indexset, jndexset)))
end


"""
function evaluate(tt::TensorTrain{V}, indexset::CartesianIndex) where {V}

Expand Down Expand Up @@ -175,6 +206,38 @@ function sum(tt::AbstractTensorTrain{V}) where {V}
return only(v)
end

"""
function average(tt::TensorTrain{V}) where {V}

Evaluates the average of the tensor train approximation over all lattice sites in an efficient
factorized manner.
"""
function average(tt::AbstractTensorTrain{V}) where {V}
v = transpose(sum(tt[1], dims=(1, 2))[1, 1, :]) / length(tt[1][1, :, 1])
for T in tt[2:end]
v *= sum(T, dims=2)[:, 1, :] / length(T[1, :, 1])
end
return only(v)
end

"""
function weightedsum(tt::TensorTrain{V}, w::Vector{V}) where {V}

Evaluates the weighted sum of the tensor train approximation over all lattice sites in an efficient
factorized manner, where w is the vector of vector of weights which has the same length and the same sizes as tt.
"""
function weightedsum(tt::AbstractTensorTrain{V}, w::Vector{Vector{V}}) where {V}
length(tt) == length(w) || throw(DimensionMismatch("The length of the Tensor Train is different from the one of the weight vector ($(length(tt)) and $(length(w)))."))
size(tt[1])[2] == length(w[1]) || throw(DimensionMismatch("The dimension at site 1 of the Tensor Train is different from the one of the weight vector ($(size(tt[1])[2]) and $(length(w[1])))."))
v = transpose(sum(tt[1].*w[1]', dims=(1, 2))[1, 1, :])
for i in 2:length(tt)
size(tt[i])[2] == length(w[i]) || throw(DimensionMismatch("The dimension at site $(i) of the Tensor Train is different from the one of the weight vector ($(size(tt[i])[2]) and $(length(w[i])))."))
v *= sum(tt[i].*w[i]', dims=2)[:, 1, :]
end
return only(v)
end


function _addtttensor(
A::Array{V}, B::Array{V};
factorA=one(V), factorB=one(V),
Expand Down Expand Up @@ -215,7 +278,7 @@ See also: [`+`](@ref)
function add(
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
factorlhs=one(V), factorrhs=one(V),
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int), normalizeerror::Bool=true
) where {V}
if length(lhs) != length(rhs)
throw(DimensionMismatch("Two tensor trains with different length ($(length(lhs)) and $(length(rhs))) cannot be added elementwise."))
Expand All @@ -233,7 +296,7 @@ function add(
for ell in 1:L
]
)
compress!(tt, :SVD; tolerance, maxbonddim)
compress!(tt, :SVD; tolerance, maxbonddim, normalizeerror)
return tt
end

Expand All @@ -247,9 +310,9 @@ Subtract two tensor trains `lhs` and `rhs`. See [`add`](@ref).
"""
function subtract(
lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V};
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int)
tolerance::Float64=0.0, maxbonddim::Int=typemax(Int), normalizeerror::Bool=true
) where {V}
return add(lhs, rhs; factorrhs=-1 * one(V), tolerance, maxbonddim)
return add(lhs, rhs; factorrhs=-1 * one(V), tolerance, maxbonddim, normalizeerror)
end

@doc raw"""
Expand All @@ -270,6 +333,226 @@ function Base.:-(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where
return subtract(lhs, rhs)
end

function leftcanonicalize!(tt::AbstractTensorTrain{ValueType}) where {ValueType}
n = length(tt) # Number of sites
for i in 1:n-1
Q, R = qr(reshape(tt[i], prod(size(tt[i])[1:end-1]), size(tt[i])[end]))
Q = Matrix(Q)

tt[i] = reshape(Q, size(tt[i])[1:end-1]..., size(Q, 2)) # New bond dimension after Q

tmptt = reshape(tt[i+1], size(R, 2), :) # Reshape next tensor
tmptt .= Matrix(R) * tmptt
tt[i+1] = reshape(tmptt, size(tt[i+1])...) # Reshape back
end
end

# This creates a TensorTrain which has every site right-canonical except the last
function rightcanonicalize!(tt::AbstractTensorTrain{ValueType}) where {ValueType}
n = length(tt) # Number of sites
for i in n:-1:2
# Reshape W_i into a matrix (merging right bond and physical indices)
W = tt[i]
χl, d1, d2, χr = size(W)
W_mat = reshape(W, χl, d1*d2*χr)

# Perform RQ decottsition: W_mat = R * Q
F = lq(reverse(W_mat, dims=1))
R, Q = reverse(F.L), reverse(Matrix(F.Q), dims=1) # https://discourse.julialang.org/t/rq-decomposition/112795/13

# Reshape Q back into the MPO tensor
tt[i] = reshape(Q, size(Q, 1), d1, d2, χr) # New bond dimension after Q

# Update the previous tt tensor by absorbing R
tmptt = reshape(tt[i-1], :, size(R, 1)) # Reshape previous tensor
tmptt .= tmptt * Matrix(R)
tt[i-1] = reshape(tmptt, size(tt[i-1], 1), d1, d2, size(tt[i-1], 4)) # Reshape back
end
end

function extract_vidal(tt::Vector{Array{ValueType, N}}) where {ValueType, N}
n = length(tt)
V = Vector{Array{ValueType, 2}}(undef, n-1)
V_1 = Vector{Array{ValueType, 2}}(undef, n-1)
A = deepcopy(tt)

# TODO "centercanonicalize!(A, n)"
for i in 1:n-1
Q, R = qr(reshape(A[i], prod(size(A[i])[1:3]), size(A[i])[end]))
A[i] = reshape(Matrix(Q), size(A[i])...)
A[i+1] = _contract(Matrix(R), A[i+1], (2,), (1,))
end

for i in n:-1:2
left, diamond, right, _, _ = _factorize(
reshape(A[i], size(A[i])[1], prod(size(A[i])[2:4])),
:SVD; tolerance=0.0, maxbonddim=size(A[i])[1], diamond=true
)

V[i-1] = Diagonal(diamond)
V_1[i-1] = Diagonal(diamond.^-1)

A[i] = reshape(right, size(A[i])...)
A[i-1] = _contract(A[i-1], left*V[i-1], (4,), (1,))
end

for i in 1:n-1
A[i] = _contract(A[i], V_1[i], (4,), (1,))
end

return A, V
end

function vidal_to_inv(G::Vector{Array{ValueType, 4}}, L::Vector{Matrix{ValueType}}) where {ValueType}
n = length(G)
Psi = Vector{Array{ValueType, 4}}(undef, n)
V = Vector{Array{ValueType, 2}}(undef, n)

Psi[1] = _contract(G[1], L[1], (4,), (1,))
for i in 2:n-1
Psi[i] = _contract(L[i-1], G[i], (2,), (1,))
Psi[i] = _contract(Psi[i], L[i], (4,), (1,))
end
Psi[n] = _contract(L[n-1], G[n], (2,), (1,))

for i in 1:n-1
V[i] = Diagonal(diag(L[i]).^-1)
end

return Psi, V
end


function centercanonicalize!(tt::Vector{Array{ValueType, N}}, center::Int; old_center::Int=0) where {ValueType, N}
orthogonality = checkorthogonality(tt)
n = length(tt) # Number of sites

if count(==( :N ), orthogonality) == 1
old_center_ = findfirst(==( :N ), orthogonality)
if old_center_ == nothing # Useless, but help JET compiling
old_center_ = old_center
end
# println("Sto canonicalizzando centrando in $center. ho trovato il centro in $old_center_. Quindi flipperò: $(center < old_center_ ? [size(tt[i]) for i in center:old_center_] : [size(tt[i]) for i in old_center_:center])")
if old_center != 0 && old_center != old_center_
println("Warning! In centercanonicalize!() old_center has been set as $old_center, but the real old center is $old_center_")
end
elseif old_center == 0
old_center_ = 1
else
old_center_ = old_center
end
# LEFT
for i in old_center_:center-1
Q, R = qr(reshape(tt[i], prod(size(tt[i])[1:end-1]), size(tt[i])[end]))
Q = Matrix(Q)

tt[i] = reshape(Q, size(tt[i])[1:end-1]..., size(Q, 2)) # New bond dimension after Q

tmptt = reshape(tt[i+1], size(R, 2), :) # Reshape next tensor
tmptt = Matrix(R) * tmptt
tt[i+1] = reshape(tmptt, size(tt[i+1])...) # Reshape back
end
# RIGHT
if count(==( :N ), orthogonality) == 1
old_center_ = findfirst(==( :N ), orthogonality)
if old_center_ == nothing # Useless, but help JET compiling
old_center_ = old_center
end
if old_center != 0 && old_center != old_center_
println("Warning! In centercanonicalize!() old_center has been set as $old_center, but the real old center is $old_center_")
end
elseif old_center == 0
old_center_ = n
else
old_center_ = old_center
end
for i in old_center_:-1:center+1
W = tt[i]
χl, d1, d2, χr = size(W)
W_mat = reshape(W, χl, d1*d2*χr)

L, Q = lq(W_mat)
Q = Matrix(Q)
# Reshape Q back into the tt tensor
tt[i] = reshape(Q, size(Q, 1), d1, d2, χr) # New bond dimension after Q

# Update the previous tt tensor by absorbing L
tmptt = reshape(tt[i-1], :, size(L, 1)) # Reshape previous tensor
tmptt = tmptt * Matrix(L)
tt[i-1] = reshape(tmptt, size(tt[i-1], 1), d1, d2, size(tmptt, 2)) # Reshape back
end
end

function move_center_right!(tt, i)
A = tt[i]
d = size(A)
A_mat = reshape(A, prod(d[1:end-1]), d[end])
Q, R = qr(A_mat)
Q = Matrix(Q)
tt[i] = reshape(Q, d[1:end-1]..., size(Q, 2))

B = tt[i+1]
B_mat = reshape(B, size(R, 2), :)
B_mat .= Matrix(R) * B_mat
tt[i+1] = reshape(B_mat, size(B)...)
end

function move_center_left!(tt, i)
A = tt[i]
d = size(A)
A_mat = reshape(A, d[1], prod(d[2:end]))
L, Q = lq(A_mat)
Q = Matrix(Q)
tt[i] = reshape(Q, size(Q,1), d[2:end]...)

B = tt[i-1]
B_mat = reshape(B, :, size(L,1))
B_mat .= B_mat * Matrix(L)
tt[i-1] = reshape(B_mat, size(B)[1:3]..., size(L,1))
end


function leftcanonicalize(tt::AbstractTensorTrain{ValueType}) where {ValueType}
tt_ = deepcopy(tt)
leftcanonicalize!(tt_)
return tt_
end

# This creates a TensorTrain which has every site right-canonical except the last
function rightcanonicalize(tt::AbstractTensorTrain{ValueType}) where {ValueType}
tt_ = deepcopy(tt)
rightcanonicalize!(tt_)
return tt_
end

# This creates a TensorTrain which has every site right-canonical except the last
function centercanonicalize(tt::Vector{Array{ValueType, N}}, center::Int; old_center::Int=0) where {ValueType, N}
tt_ = deepcopy(tt)
centercanonicalize!(tt_, center; old_center)
return tt_
end

function checkorthogonality(tt::Vector{Array{ValueType, N}}) where {ValueType, N}
ort = Vector{Symbol}(undef, length(tt))
for i in 1:length(tt)
W = tt[i]
left_check = _contract(permutedims(W, (4,2,3,1,)), W, (2,3,4,),(2,3,1))
right_check = _contract(W, permutedims(W, (4,2,3,1,)), (2,3,4,),(2,3,1))
is_left = isapprox(left_check, I, atol=1e-7)
is_right = isapprox(right_check, I, atol=1e-7)
ort[i] = if is_left && is_right
:O # Orthogonal
elseif is_left
:L # Left orthogonal
elseif is_right
:R # Right orthogonal
else
:N # Non orthogonal
end
end
return ort
end

"""
Squared Frobenius norm of a tensor train.
"""
Expand Down
Loading