2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "T4AITensorCompat"
uuid = "864b39ca-388c-4f43-a593-a1076cf4b253"
authors = ["Hiroshi Shinaoka <[email protected]>"]
version = "0.3.0"
version = "0.4.0"
license = "Apache-2.0"

[deps]
2 changes: 1 addition & 1 deletion docs/src/index.md
@@ -35,7 +35,7 @@ findsite
findsites
isortho
orthocenter
fit
lognorm
```

## Default Parameters
3 changes: 2 additions & 1 deletion src/T4AITensorCompat.jl
@@ -27,11 +27,12 @@ import LinearAlgebra
export TensorTrain
export MPS, MPO, AbstractMPS # Temporary aliases for migration
export contract
export fit # Fit function for summing multiple tensor trains with coefficients
#export fit # Fit function for summing multiple tensor trains with coefficients
export truncate, truncate!
export maxlinkdim, siteinds
export linkinds, linkind, findsite, findsites, isortho, orthocenter # Functions for compatibility
export default_maxdim, default_cutoff, default_nsweeps
export lognorm # Log norm function


include("defaults.jl")
11 changes: 9 additions & 2 deletions src/contraction.jl
@@ -28,6 +28,7 @@ import ITensorMPS
import ITensors: Algorithm, @Algorithm_str
import ITensorMPS: setleftlim!, setrightlim!
import LinearAlgebra
import ..default_cutoff, ..default_maxdim, ..default_nsweeps
include("contraction/fitalgorithm.jl")
include("contraction/densitymatrix.jl")
include("contraction/fitalgorithm_sum.jl")
@@ -112,19 +113,25 @@ bond dimensions while maintaining numerical accuracy.
result = fit([tt1, tt2, tt3], init_tt; coeffs=[1.0, 2.0, 0.5])
```
"""
#==
FIXME (HS): I observed that this function is sometimes less accurate than a direct sum of the input states.
function fit(
input_states::AbstractVector{TensorTrain},
init::TensorTrain;
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
cutoff::Real = default_cutoff(),
maxdim::Int = default_maxdim(),
nsweeps::Int = default_nsweeps(),
kwargs...,
)::TensorTrain
# Convert TensorTrain objects to ITensorMPS.MPS
mps_inputs = [ITensorMPS.MPS(tt) for tt in input_states]
mps_init = ITensorMPS.MPS(init)

# Call the fit function from ContractionImpl
mps_result = ContractionImpl.fit(mps_inputs, mps_init; coeffs=coeffs, kwargs...)
mps_result = ContractionImpl.fit(mps_inputs, mps_init; coeffs=coeffs, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)

# Convert back to TensorTrain
return TensorTrain(mps_result, init.llim, init.rlim)
end
end
==#
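The FIXME above records that the sweep-based `fit` can be less accurate than simply summing the inputs, which is why the wrapper is commented out rather than fixed here. For reference, a minimal sketch of that direct-sum route, under the assumption that `TensorTrain` supports `+` and scalar `*` (both exercised in the tests below) and that the exported `truncate` accepts `cutoff` and `maxdim` keywords; `sum_direct` is not part of this PR:

```julia
# Sketch only (not part of this PR): accumulate coeffs[k] * input_states[k]
# by direct summation, then compress. Assumes TensorTrain supports `+` and
# scalar `*`, and that the exported `truncate` forwards `cutoff`/`maxdim`.
function sum_direct(
    input_states::AbstractVector{TensorTrain},
    coeffs::AbstractVector{<:Number} = ones(Int, length(input_states));
    cutoff::Real = default_cutoff(),
    maxdim::Int = default_maxdim(),
)::TensorTrain
    acc = coeffs[1] * input_states[1]
    for k in 2:length(input_states)
        acc = acc + coeffs[k] * input_states[k]
    end
    return truncate(acc; cutoff=cutoff, maxdim=maxdim)
end
```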
11 changes: 10 additions & 1 deletion src/contraction/fitalgorithm_sum.jl
@@ -240,6 +240,9 @@ function fit(
input_states::AbstractVector{MPS},
init::MPS;
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
cutoff::Real = default_cutoff(),
maxdim::Int = default_maxdim(),
nsweeps::Int = default_nsweeps(),
kwargs...,
)::MPS
links = ITensors.sim.(linkinds(init))
@@ -249,6 +252,9 @@
reduced_operator,
init;
updater = contract_operator_state_updater,
cutoff = cutoff,
maxdim = maxdim,
nsweeps = nsweeps,
kwargs...,
)
end
@@ -257,11 +263,14 @@ function fit(
input_states::AbstractVector{MPO},
init::MPO;
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
cutoff::Real = default_cutoff(),
maxdim::Int = default_maxdim(),
nsweeps::Int = default_nsweeps(),
kwargs...,
)::MPO
to_mps(Ψ::MPO) = MPS([x for x in Ψ])

res = fit(to_mps.(input_states), to_mps(init); coeffs = coeffs, kwargs...)
res = fit(to_mps.(input_states), to_mps(init); coeffs = coeffs, cutoff = cutoff, maxdim = maxdim, nsweeps = nsweeps, kwargs...)
return MPO([x for x in res])
end
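
Both internal methods now default `cutoff`, `maxdim`, and `nsweeps` to the package-wide values and thread them through to the sweep driver. A hedged usage sketch, assuming the internal `fit` method is in scope (it lives in the contraction submodule and is not exported) and using hypothetical inputs built with standard ITensorMPS constructors:

```julia
using ITensors, ITensorMPS

# Hypothetical inputs: three MPS on the same sites (names not from this PR).
s = siteinds("S=1/2", 6)
a, b = random_mps(s; linkdims=4), random_mps(s; linkdims=4)
psi0 = random_mps(s; linkdims=8)

# Fit psi0 to the combination 1.0*a - 0.5*b with explicit truncation controls;
# the keywords are forwarded to the alternating-update sweeps in the hunk above.
psi = fit([a, b], psi0; coeffs=[1.0, -0.5], cutoff=1e-10, maxdim=64, nsweeps=4)
```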

51 changes: 50 additions & 1 deletion src/tensortrain.jl
@@ -274,6 +274,52 @@ function Base.isapprox(stt1::TensorTrain, stt2::TensorTrain; kwargs...)
return ITensorMPS.isapprox(mps1, mps2; kwargs...)
end

"""
isapprox(x::TensorTrain, y::TensorTrain; atol::Real=0, rtol::Real=Base.rtoldefault(LinearAlgebra.promote_leaf_eltypes(x), LinearAlgebra.promote_leaf_eltypes(y), atol))

Check if two TensorTrain objects are approximately equal using explicit tolerance parameters.

This function computes the distance between two tensor trains and compares it against
absolute and relative tolerances.

# Arguments
- `x::TensorTrain`: First tensor train
- `y::TensorTrain`: Second tensor train

# Keyword Arguments
- `atol::Real`: Absolute tolerance (default: 0)
- `rtol::Real`: Relative tolerance (default: computed from element types)

# Returns
- `Bool`: `true` if `norm(x - y) <= max(atol, rtol * max(norm(x), norm(y)))`, `false` otherwise

# Examples
```julia
tt1 ≈ tt2 # Using default tolerances
isapprox(tt1, tt2; atol=1e-10, rtol=1e-8) # Using explicit tolerances
```
"""
function isapprox(
x::TensorTrain,
y::TensorTrain;
atol::Real = 0,
rtol::Real = Base.rtoldefault(
LinearAlgebra.promote_leaf_eltypes(x), LinearAlgebra.promote_leaf_eltypes(y), atol
),
)
d = norm(x - y)
if isfinite(d)
return d <= max(atol, rtol * max(norm(x), norm(y)))
else
error("In `isapprox(x::TensorTrain, y::TensorTrain)`, `norm(x - y)` is not finite")
end
end

# Extend LinearAlgebra.promote_leaf_eltypes for TensorTrain
function LinearAlgebra.promote_leaf_eltypes(tt::TensorTrain)
return LinearAlgebra.promote_leaf_eltypes(tt.data)
end
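
The `promote_leaf_eltypes` extension above exists so the default `rtol` in the new `isapprox` can be derived from the element types of both operands. A small illustration of how that default evaluates, assuming `Float64` (or `ComplexF64`) leaves:

```julia
# Base.rtoldefault(T, S, atol) returns max(sqrt(eps(real(T))), sqrt(eps(real(S))))
# when atol == 0, and 0 otherwise; this is what the signature above evaluates.
Base.rtoldefault(Float64, Float64, 0)      # ≈ 1.4901161193847656e-8
Base.rtoldefault(ComplexF64, Float64, 0)   # same value, based on real(ComplexF64) == Float64
Base.rtoldefault(Float64, Float64, 1e-6)   # 0.0; an explicit atol disables the relative default
```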

"""
norm(stt::TensorTrain)

@@ -297,13 +343,16 @@ Compute the log norm of a TensorTrain object.
This function delegates to ITensorMPS.lognorm for efficient computation.
The log norm is useful when the norm may be very large to avoid overflow.
"""
function ITensorMPS.lognorm(stt::TensorTrain)
function lognorm(stt::TensorTrain)
# Convert to MPS and delegate to ITensorMPS.lognorm
mps = ITensorMPS.MPS(stt)

return ITensorMPS.lognorm(mps)
end

# Also extend ITensorMPS.lognorm for backward compatibility
ITensorMPS.lognorm(stt::TensorTrain) = lognorm(stt)
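
A short usage sketch for the relocated `lognorm`, reusing the index and tensor pattern from the tests in this PR; the identity `lognorm(tt) ≈ log(norm(tt))` holds whenever the norm itself is representable, which is the overflow concern the docstring mentions:

```julia
using ITensors, LinearAlgebra, T4AITensorCompat
import ITensorMPS

# Two-site TensorTrain, mirroring the construction used in tensortrain_test.jl.
i1, i2, i3 = Index(2, "i1"), Index(3, "i2"), Index(2, "i3")
tt = TensorTrain([random_itensor(i1, i2), random_itensor(i2, i3)])

logn = lognorm(tt)               # new unqualified export
logn ≈ log(norm(tt))             # true whenever norm(tt) is representable
ITensorMPS.lognorm(tt) ≈ logn    # backward-compatible method defined above
```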

"""
Subtract two TensorTrain objects using ITensors.Algorithm("directsum")

85 changes: 0 additions & 85 deletions test/fit_test.jl

This file was deleted.

51 changes: 51 additions & 0 deletions test/tensortrain_test.jl
@@ -466,6 +466,57 @@
@test abs(norm(stt1) - ITensorMPS.norm(mps_direct)) / ITensorMPS.norm(mps_direct) < 1e-13
end

@testset "TensorTrain isapprox function" begin
# Create test indices
i1 = Index(2, "i1")
i2 = Index(3, "i2")
i3 = Index(2, "i3")

# Create test tensors
t1a = random_itensor(i1, i2)
t2a = random_itensor(i2, i3)
stt1 = TensorTrain([t1a, t2a])

t1b = random_itensor(i1, i2)
t2b = random_itensor(i2, i3)
stt2 = TensorTrain([t1b, t2b])

# Test that identical tensor trains are approximately equal
@test isapprox(stt1, stt1)
@test isapprox(stt1, stt1; atol=1e-10)
@test isapprox(stt1, stt1; rtol=1e-8)

# Test that different tensor trains are not approximately equal (with strict tolerance)
@test !isapprox(stt1, stt2; atol=1e-15, rtol=1e-15)

# Test with explicit atol parameter
@test isapprox(stt1, stt1; atol=1e-10)

# Test with explicit rtol parameter
@test isapprox(stt1, stt1; rtol=1e-8)

# Test that isapprox works with Base.isapprox (≈ operator)
@test stt1 ≈ stt1
# Note: ≈ operator doesn't accept keyword arguments directly
# Use isapprox function instead for custom tolerances
@test !isapprox(stt1, stt2; atol=1e-15, rtol=1e-15)

# Test that isapprox is consistent with dist
# If distance is small, they should be approximately equal
small_diff = stt1 + 1e-12 * stt2
@test isapprox(stt1, small_diff; atol=1e-10)

# Test that isapprox handles different lengths correctly
t3 = random_itensor(i1, i2)
stt3 = TensorTrain([t3]) # Different length
@test !isapprox(stt1, stt3) # Should return false for different lengths

# Test promote_leaf_eltypes extension
using LinearAlgebra
eltype_result = LinearAlgebra.promote_leaf_eltypes(stt1)
@test eltype_result isa Type
end

@testset "Error handling for tensor train addition" begin
# Create test indices
i1 = Index(2, "i1")