Skip to content

Commit 24cffee

Browse files
authored
Merge pull request #18 from tensor4all/17-improvements-for-compress-function-using-svd
17 improvements for compress function using svd
2 parents b849603 + a3cce81 commit 24cffee

File tree

4 files changed

+100
-13
lines changed

4 files changed

+100
-13
lines changed

src/abstracttensortrain.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,25 @@ Subtraction of two tensor trains. If `c = a - b`, then `c(v) ≈ a(v) - b(v)` at
269269
function Base.:-(lhs::AbstractTensorTrain{V}, rhs::AbstractTensorTrain{V}) where {V}
270270
return subtract(lhs, rhs)
271271
end
272+
273+
"""
Squared Frobenius norm of a tensor train.
"""
function LA.norm2(tt::AbstractTensorTrain{V})::Float64 where {V}
    # Transfer matrix of site n: contract the site tensor with its complex
    # conjugate over the physical (site) indices, leaving a matrix indexed by
    # the paired left bonds (rows) and paired right bonds (columns).
    function transfermatrix(n)::Matrix{V}
        sitet = sitetensor(tt, n)
        t3 = reshape(sitet, size(sitet)[1], :, size(sitet)[end])
        # (lc, s, rc) * (l, s, r) => (lc, rc, l, r)
        tct = _contract(conj.(t3), t3, (2,), (2,))
        tct = permutedims(tct, (1, 3, 2, 4))
        return reshape(
            tct,
            size(tct, 1) * size(tct, 2),
            size(tct, 3) * size(tct, 4),
        )
    end
    # Multiplying the transfer matrices along the chain collapses the network
    # to a 1x1 matrix whose single entry is |tt|^2 (real up to rounding).
    acc = transfermatrix(1)
    for n in 2:length(tt)
        acc = acc * transfermatrix(n)
    end
    return real(only(acc))
end
287+
288+
"""
Frobenius norm of a tensor train.

Equivalent to `sqrt(LA.norm2(tt))`.
"""
function LA.norm(tt::AbstractTensorTrain{V})::Float64 where {V}
    return sqrt(LA.norm2(tt))
end

src/tensortrain.jl

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -93,25 +93,44 @@ function tensortrain(tci)
9393
end
9494

9595
"""
Factorize the matrix `A` into a left and a right factor, returning
`(leftfactor, rightfactor, rank)` with `A ≈ leftfactor * rightfactor`.

# Arguments
- `A`: matrix to factorize.
- `method`: one of `:LU` (rank-revealing LU), `:CI` (cross interpolation),
  `:SVD`.

# Keywords
- `tolerance`: truncation tolerance. Interpreted as a *relative* tolerance
  when `normalizeerror == true`, as an *absolute* tolerance otherwise.
- `maxbonddim`: upper bound on the rank that is kept.
- `leftorthogonal`: if `true`, the left factor carries the orthogonality
  (for `:SVD`: `U` alone on the left); otherwise the singular weights are
  absorbed into the left factor and `Vt` alone forms the right factor.
- `normalizeerror`: selects relative vs. absolute error control (see
  `tolerance`).

# Throws
- `ErrorException` for an unsupported `method`.
"""
function _factorize(
    A::AbstractMatrix{V}, method::Symbol;
    tolerance::Float64, maxbonddim::Int,
    leftorthogonal::Bool=false, normalizeerror::Bool=true
)::Tuple{Matrix{V},Matrix{V},Int} where {V}
    # Exactly one of the two bounds is driven by `tolerance`; the other stays
    # (near-)inactive.
    reltol = 1e-14
    abstol = 0.0
    if normalizeerror
        reltol = tolerance
    else
        abstol = tolerance
    end
    if method === :LU
        factorization = rrlu(A, abstol=abstol, reltol=reltol, maxrank=maxbonddim, leftorthogonal=leftorthogonal)
        return left(factorization), right(factorization), npivots(factorization)
    elseif method === :CI
        factorization = MatrixLUCI(A, abstol=abstol, reltol=reltol, maxrank=maxbonddim, leftorthogonal=leftorthogonal)
        return left(factorization), right(factorization), npivots(factorization)
    elseif method === :SVD
        factorization = LinearAlgebra.svd(A)
        # err[n] = squared Frobenius error of truncating after the n-th
        # singular value, i.e. sum(S[n+1:end].^2). Computed via a reverse
        # cumulative sum in O(k) instead of O(k^2) per-entry tail sums.
        s2 = factorization.S .^ 2
        tailsums = reverse(cumsum(reverse(s2)))  # tailsums[n] = sum(s2[n:end])
        err = isempty(s2) ? Float64[] : vcat(tailsums[2:end], 0.0)
        normalized_err = err ./ sum(s2)
        # Keep the smallest rank satisfying the active tolerance, capped by
        # `maxbonddim`; fall back to the full rank when no entry qualifies.
        trunci = min(
            replacenothing(findfirst(<(abstol^2), err), length(err)),
            replacenothing(findfirst(<(reltol^2), normalized_err), length(normalized_err)),
            maxbonddim
        )
        if leftorthogonal
            return (
                factorization.U[:, 1:trunci],
                Diagonal(factorization.S[1:trunci]) * factorization.Vt[1:trunci, :],
                trunci
            )
        else
            return (
                factorization.U[:, 1:trunci] * Diagonal(factorization.S[1:trunci]),
                factorization.Vt[1:trunci, :],
                trunci
            )
        end
    else
        error("Not implemented yet.")
    end
end
@@ -131,25 +150,28 @@ function compress!(
131150
tt::TensorTrain{V,N},
132151
method::Symbol=:LU;
133152
tolerance::Float64=1e-12,
134-
maxbonddim::Int=typemax(Int)
153+
maxbonddim::Int=typemax(Int),
154+
normalizeerror::Bool=true
135155
) where {V,N}
156+
# From left to right
136157
for ell in 1:length(tt)-1
137158
shapel = size(tt.sitetensors[ell])
138159
left, right, newbonddim = _factorize(
139160
reshape(tt.sitetensors[ell], prod(shapel[1:end-1]), shapel[end]),
140-
method; tolerance, maxbonddim
161+
method; tolerance=0.0, maxbonddim=typemax(Int), leftorthogonal=true # no truncation
141162
)
142163
tt.sitetensors[ell] = reshape(left, shapel[1:end-1]..., newbonddim)
143164
shaper = size(tt.sitetensors[ell+1])
144165
nexttensor = right * reshape(tt.sitetensors[ell+1], shaper[1], prod(shaper[2:end]))
145166
tt.sitetensors[ell+1] = reshape(nexttensor, newbonddim, shaper[2:end]...)
146167
end
147168

169+
# From right to left
148170
for ell in length(tt):-1:2
149171
shaper = size(tt.sitetensors[ell])
150172
left, right, newbonddim = _factorize(
151173
reshape(tt.sitetensors[ell], shaper[1], prod(shaper[2:end])),
152-
method; tolerance, maxbonddim
174+
method; tolerance, maxbonddim, normalizeerror, leftorthogonal=false
153175
)
154176
tt.sitetensors[ell] = reshape(right, newbonddim, shaper[2:end]...)
155177
shapel = size(tt.sitetensors[ell-1])
@@ -212,6 +234,7 @@ function Base.reverse(tt::AbstractTensorTrain{V}) where {V}
212234
]))
213235
end
214236

237+
215238
"""
216239
Fitting data with a TensorTrain object.
217240
This may be useful when the interpolated function is noisy.
@@ -266,4 +289,4 @@ function fulltensor(obj::TensorTrain{T,N})::Array{T} where {T,N}
266289
end
267290
returnsize = collect(Iterators.flatten(sitedims_))
268291
return reshape(result, returnsize...)
269-
end
292+
end

test/test_blockstructure.jl

Whitespace-only changes.

test/test_tensortrain.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import TensorCrossInterpolation as TCI
2+
import LinearAlgebra as LA
23
using Random
34
using Zygote
45
using Optim
@@ -196,6 +197,47 @@ end
196197
@test ttmultileg2.(indicesmultileg) ≈ 2 .* ttmultileg.(indicesmultileg)
197198
end
198199

200+
201+
@testset "norm" begin
    sitedims_ = [[2], [2], [2]]
    N = length(sitedims_)
    bonddims = [1, 1, 1, 1]

    # All-ones, bond-dimension-1 tensor train: the squared Frobenius norm
    # equals the product of the local dimensions.
    tt = TCI.TensorTrain([
        ones(bonddims[n], sitedims_[n]..., bonddims[n+1]) for n in 1:N
    ])

    @test LA.norm2(tt) ≈ prod(only.(sitedims_))
    # Scaling the TT by 2 scales the squared norm by 4.
    @test LA.norm2(2 * tt) ≈ 4 * prod(only.(sitedims_))
    # norm must be consistent with norm2.
    @test LA.norm2(tt) ≈ LA.norm(tt)^2
end
215+
216+
@testset "compress! (SVD)" for T in [Float64, ComplexF64]
    Random.seed!(1234)
    N = 10
    sitedims_ = [[2] for _ in 1:N]
    χ = 10

    tol = 0.1
    bonddims = vcat(1, χ * ones(Int, N - 1), 1)

    # Random TT with element type T. (Previously the loop variable was
    # shadowed by `T = Float64` and never passed to `randn`, so the
    # ComplexF64 case silently re-ran the Float64 test.)
    tt = TCI.TensorTrain([
        randn(T, bonddims[n], sitedims_[n]..., bonddims[n+1]) for n in 1:N
    ])

    # normalizeerror=true: `tolerance` is relative to the norm of tt.
    tt_compressed = deepcopy(tt)
    TCI.compress!(tt_compressed, :SVD; tolerance=tol)
    @test sqrt(LA.norm2(tt - tt_compressed) / LA.norm2(tt)) < sqrt(N) * tol

    # normalizeerror=false: `tolerance` is an absolute bound, so rescale by
    # the norm to request the same relative accuracy.
    tt_compressed = deepcopy(tt)
    TCI.compress!(tt_compressed, :SVD; tolerance=LA.norm(tt) * tol, normalizeerror=false)
    @test sqrt(LA.norm2(tt - tt_compressed) / LA.norm2(tt)) < sqrt(N) * tol
end
240+
199241
@testset "tensor train cast" begin
200242
Random.seed!(10)
201243
localdims = [2, 2, 2]

0 commit comments

Comments
 (0)