Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ findsite
findsites
isortho
orthocenter
evaluate
fit
lognorm
```
Expand All @@ -54,4 +55,5 @@ random_mpo
default_maxdim
default_cutoff
default_nsweeps
default_abs_cutoff
```
4 changes: 2 additions & 2 deletions src/T4AITensorCompat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ export contract
export fit # Fit function for summing multiple tensor trains with coefficients
export truncate, truncate!
export maxlinkdim, siteinds
export linkinds, linkind, findsite, findsites, isortho, orthocenter # Functions for compatibility
export default_maxdim, default_cutoff, default_nsweeps
export linkinds, linkind, findsite, findsites, isortho, orthocenter, evaluate # Functions for compatibility
export default_maxdim, default_cutoff, default_nsweeps, default_abs_cutoff
export lognorm # Log norm function
export random_mps, random_mpo # Random tensor train generation
export product # Official API name (match ITensorMPS)
Expand Down
16 changes: 15 additions & 1 deletion src/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,18 @@ The default is `1`, which performs a single sweep.
# Returns
- `Int`: The default number of sweeps
"""
default_nsweeps() = 1
default_nsweeps() = 1

"""
default_abs_cutoff()

Return the default absolute cutoff threshold for truncating small singular values.

The default is `0.0`, which means no absolute cutoff is applied.
When `abs_cutoff > 0.0`, the effective cutoff becomes `cutoff * norm2 + abs_cutoff`,
where `norm2` is the squared norm of the tensor train.

# Returns
- `Float64`: The default absolute cutoff value (0.0)
"""
default_abs_cutoff() = 0.0
168 changes: 160 additions & 8 deletions src/tensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ end


"""
truncate!(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), kwargs...)
truncate!(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), abs_cutoff::Real=default_abs_cutoff(), kwargs...)

Truncate a TensorTrain in-place by removing small singular values.

Expand All @@ -820,16 +820,29 @@ applying ITensorMPS.truncate!, and updating the tensor data.
- `stt::TensorTrain`: The tensor train to truncate (modified in-place)

# Keyword Arguments
- `cutoff::Real`: Cutoff threshold for singular values (default: `default_cutoff()`)
- `cutoff::Real`: Relative cutoff threshold for singular values (default: `default_cutoff()`)
- `maxdim::Int`: Maximum bond dimension (default: `default_maxdim()`)
- `abs_cutoff::Real`: Absolute cutoff threshold (default: `default_abs_cutoff()`).
When `abs_cutoff > 0.0`, the effective cutoff becomes `cutoff + abs_cutoff/norm2`,
where `norm2` is the squared norm of the tensor train. This ensures the total error
is bounded by `cutoff * norm2 + abs_cutoff`.
- `kwargs...`: Additional keyword arguments passed to ITensorMPS.truncate!

# Returns
- `TensorTrain`: The modified tensor train (same object as input)
"""
function truncate!(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), kwargs...)::TensorTrain
function truncate!(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), abs_cutoff::Real=default_abs_cutoff(), kwargs...)::TensorTrain
mps = ITensorMPS.MPS(stt)
ITensorMPS.truncate!(mps; cutoff=cutoff, maxdim=maxdim, kwargs...)

# Calculate adjusted cutoff if abs_cutoff is specified
adjusted_cutoff = if abs_cutoff != 0.0
norm2 = LinearAlgebra.norm(stt)^2
cutoff + abs_cutoff / norm2
else
cutoff
end

ITensorMPS.truncate!(mps; cutoff=adjusted_cutoff, maxdim=maxdim, kwargs...)
# Update in place
for i in 1:length(stt)
stt[i] = mps[i]
Expand All @@ -838,7 +851,7 @@ function truncate!(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=
end

"""
truncate(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), kwargs...)
truncate(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), abs_cutoff::Real=default_abs_cutoff(), kwargs...)

Truncate a TensorTrain by removing small singular values, returning a new object.

Expand All @@ -849,16 +862,29 @@ applying ITensorMPS.truncate!, and creating a new TensorTrain from the result.
- `stt::TensorTrain`: The tensor train to truncate

# Keyword Arguments
- `cutoff::Real`: Cutoff threshold for singular values (default: `default_cutoff()`)
- `cutoff::Real`: Relative cutoff threshold for singular values (default: `default_cutoff()`)
- `maxdim::Int`: Maximum bond dimension (default: `default_maxdim()`)
- `abs_cutoff::Real`: Absolute cutoff threshold (default: `default_abs_cutoff()`).
When `abs_cutoff > 0.0`, the effective cutoff becomes `cutoff + abs_cutoff/norm2`,
where `norm2` is the squared norm of the tensor train. This ensures the total error
is bounded by `cutoff * norm2 + abs_cutoff`.
- `kwargs...`: Additional keyword arguments passed to ITensorMPS.truncate!

# Returns
- `TensorTrain`: A new truncated tensor train
"""
function truncate(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), kwargs...)::TensorTrain
function truncate(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), abs_cutoff::Real=default_abs_cutoff(), kwargs...)::TensorTrain
mps = ITensorMPS.MPS(stt)
ITensorMPS.truncate!(mps; cutoff=cutoff, maxdim=maxdim, kwargs...)

# Calculate adjusted cutoff if abs_cutoff is specified
adjusted_cutoff = if abs_cutoff != 0.0
norm2 = LinearAlgebra.norm(stt)^2
cutoff + abs_cutoff / norm2
else
cutoff
end

ITensorMPS.truncate!(mps; cutoff=adjusted_cutoff, maxdim=maxdim, kwargs...)
return TensorTrain(mps)
end

Expand Down Expand Up @@ -1124,6 +1150,132 @@ function orthocenter(tt::TensorTrain)
return ITensorMPS.orthocenter(mps)
end

"""
evaluate(tt::TensorTrain, sites::Vector{Index}, values::Vector{Int})

Evaluate a TensorTrain at specific site index values.

This function evaluates the tensor train by contracting it with onehot tensors
at each site corresponding to the given values.

# Arguments
- `tt::TensorTrain`: The tensor train to evaluate
- `sites::Vector{Index}`: Vector of site indices (one per tensor)
- `values::Vector{Int}`: Vector of values (one per site)

# Returns
- The scalar evaluation result

# Examples
```julia
sites = siteinds(tt)
values = [1, 2, 1]
result = evaluate(tt, sites, values)
```
"""
function evaluate(tt::TensorTrain, sites::Vector{Index}, values::Vector{Int})
length(tt) == length(sites) || error("Length mismatch: TensorTrain has $(length(tt)) tensors but $(length(sites)) site indices")
length(sites) == length(values) || error("Length mismatch: $(length(sites)) site indices but $(length(values)) values")

# Evaluate by contracting each tensor with onehot tensors
result = reduce(*, [
tt[n] * ITensors.onehot(sites[n] => values[n])
for n in 1:length(tt)
])
return only(result)
end

"""
evaluate(tt::TensorTrain, sites::Vector{Vector{Index}}, values::Vector{Int})

Evaluate a TensorTrain at specific site index values.

This function evaluates the tensor train by contracting it with onehot tensors
at each site corresponding to the given values.

# Arguments
- `tt::TensorTrain`: The tensor train to evaluate
- `sites::Vector{Vector{Index}}`: Vector of site index vectors (one per tensor)
- `values::Vector{Int}`: Vector of values (one per site)

# Returns
- The scalar evaluation result

# Examples
```julia
sites = siteinds(tt)
values = [1, 2, 1]
result = evaluate(tt, sites, values)
```
"""
function evaluate(tt::TensorTrain, sites::Vector{Vector{Index}}, values::Vector{Int})
length(tt) == length(sites) || error("Length mismatch: TensorTrain has $(length(tt)) tensors but $(length(sites)) site groups")
length(sites) == length(values) || error("Length mismatch: $(length(sites)) site groups but $(length(values)) values")

# Evaluate by contracting each tensor with onehot tensors
result = reduce(*, [
tt[n] * reduce(*, [ITensors.onehot(idx => values[n]) for idx in sites[n]])
for n in 1:length(tt)
])
return only(result)
end

"""
evaluate(tt::TensorTrain, pairs::Vector{Tuple{Index, Int}})

Evaluate a TensorTrain at specific site index values using (index, value) pairs.

This function evaluates the tensor train by contracting it with onehot tensors
for each (index, value) pair.

# Arguments
- `tt::TensorTrain`: The tensor train to evaluate
- `pairs::Vector{Tuple{Index, Int}}`: Vector of (index, value) pairs

# Returns
- The scalar evaluation result

# Examples
```julia
sites = siteinds(tt)
pairs = collect(zip(sites[1], [1, 2]))
result = evaluate(tt, pairs)
```
"""
function evaluate(tt::TensorTrain, pairs::Vector{Tuple{Index, Int}})
# Group pairs by tensor position
sites = siteinds(tt)
site_to_pos = Dict{Index, Int}()
for (pos, site_vec) in enumerate(sites)
for site in site_vec
site_to_pos[site] = pos
end
end

# Group pairs by tensor position
tensor_pairs = Dict{Int, Vector{Tuple{Index, Int}}}()
for pair in pairs
idx, val = pair
pos = get(site_to_pos, idx, nothing)
pos === nothing && error("Index $idx not found in TensorTrain")
if !haskey(tensor_pairs, pos)
tensor_pairs[pos] = Vector{Tuple{Index, Int}}()
end
push!(tensor_pairs[pos], pair)
end

# Evaluate by contracting each tensor with onehot tensors
result = reduce(*, [
if haskey(tensor_pairs, n)
tt[n] * reduce(*, [ITensors.onehot(idx => val) for (idx, val) in tensor_pairs[n]])
else
tt[n]
end
for n in 1:length(tt)
])
return only(result)
end

#===
Random tensor train generation functions
====#
Expand Down
51 changes: 50 additions & 1 deletion test/tensortrain_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testitem "tensortrain.jl" begin
include("util.jl")

import T4AITensorCompat: TensorTrain, dist, siteinds, random_mps, random_mpo
import T4AITensorCompat: TensorTrain, dist, siteinds, random_mps, random_mpo, default_abs_cutoff
import ITensors: ITensor, Index, random_itensor, dim, inds
import ITensorMPS
import ITensors: Algorithm, @Algorithm_str
Expand Down Expand Up @@ -684,6 +684,55 @@
@test relative_error(a_plus_a, a_plus_a_truncated_copy) < 1e-13
end

@testset "TensorTrain truncate with abs_cutoff" begin
# Create a simple 2-site MPS with larger bond dimension
i1 = Index(2, "i1")
i2 = Index(2, "i2")
l1 = Index(5, "Link,l1")

t1 = random_itensor(i1, l1)
t2 = random_itensor(l1, i2)

# Create TensorTrain
a = TensorTrain([t1, t2], 1, 5)

# Test that a + a doubles the bond dimension
a_plus_a = a + a
original_norm2 = norm(a_plus_a)^2
original_maxdim = T4AITensorCompat.maxlinkdim(a_plus_a)

# Test default_abs_cutoff
@test default_abs_cutoff() == 0.0

# Test truncate with abs_cutoff = 0.0 (should behave same as default)
a_truncated_default = T4AITensorCompat.truncate(a_plus_a; maxdim=10, abs_cutoff=0.0)
a_truncated_no_abs = T4AITensorCompat.truncate(a_plus_a; maxdim=10)
@test relative_error(a_truncated_default, a_truncated_no_abs) < 1e-14

# Test truncate with abs_cutoff > 0.0 (should result in more aggressive truncation)
# Use a relatively large abs_cutoff to see the effect
abs_cutoff_value = original_norm2 * 1e-6 # Small but non-negligible
a_truncated_abs = T4AITensorCompat.truncate(a_plus_a; maxdim=10, abs_cutoff=abs_cutoff_value)

# With abs_cutoff, the effective cutoff is larger, so truncation should be more aggressive
# The bond dimension should be smaller or equal
@test T4AITensorCompat.maxlinkdim(a_truncated_abs) <= T4AITensorCompat.maxlinkdim(a_truncated_no_abs)

# The error should be larger when using abs_cutoff (more aggressive truncation)
error_with_abs = relative_error(a_plus_a, a_truncated_abs)
error_without_abs = relative_error(a_plus_a, a_truncated_no_abs)
@test error_with_abs >= error_without_abs - 1e-14 # Allow small numerical differences

# Test truncate! with abs_cutoff
a_plus_a_copy = deepcopy(a_plus_a)
T4AITensorCompat.truncate!(a_plus_a_copy; maxdim=10, abs_cutoff=abs_cutoff_value)
@test relative_error(a_plus_a, a_plus_a_copy) ≈ error_with_abs atol=1e-14

# Test that abs_cutoff is only computed when != 0.0
# This is tested implicitly by the fact that the code runs without errors
# and produces correct results
end

@testset "TensorTrain siteinds function - random_mps" begin
# Test siteinds with random_mps
sites = [Index(2, "Site,n=$n") for n in 1:5]
Expand Down
Loading