Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/Manifest.toml
/docs/build/
/docs/Manifest.toml
*.log
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ license = "Apache-2.0"
ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
ITensors = "0.9"
Expand All @@ -16,6 +17,8 @@ julia = "1"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InitialValues = "22cec73e-a143-4acd-8c0c-585de75c995d"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Expand Down
10 changes: 10 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ TensorTrain(::ITensorMPS.MPO, ::Int, ::Int)

```@docs
contract
product
apply
truncate
truncate!
maxlinkdim
Expand All @@ -35,9 +37,17 @@ findsite
findsites
isortho
orthocenter
fit
lognorm
```

## Random Generators

```@docs
random_mps
random_mpo
```

## Default Parameters

```@docs
Expand Down
21 changes: 21 additions & 0 deletions ext/T4AITensorCompatChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Extension module to resolve ambiguities with ChainRulesCore and InitialValues
# This module is automatically loaded when ChainRulesCore or InitialValues are available

module T4AITensorCompatChainRulesCoreExt

using ChainRulesCore: NoTangent, AbstractThunk, NotImplemented, ZeroTangent
using InitialValues: NonspecificInitialValue, SpecificInitialValue
using ..T4AITensorCompat: TensorTrain

# Resolve ambiguities with ChainRulesCore types
Base.:+(stt::TensorTrain, ::NoTangent) = error("Cannot add ChainRulesCore.NoTangent to TensorTrain")
Base.:+(stt::TensorTrain, ::AbstractThunk) = error("Cannot add ChainRulesCore.AbstractThunk to TensorTrain")
Base.:+(stt::TensorTrain, ::NotImplemented) = error("Cannot add ChainRulesCore.NotImplemented to TensorTrain")
Base.:+(stt::TensorTrain, ::ZeroTangent) = error("Cannot add ChainRulesCore.ZeroTangent to TensorTrain")

# Resolve ambiguities with InitialValues types
Base.:+(stt::TensorTrain, ::Union{NonspecificInitialValue, SpecificInitialValue{typeof(+)}}) =
error("Cannot add InitialValues to TensorTrain")

end

8 changes: 5 additions & 3 deletions src/T4AITensorCompat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.

module T4AITensorCompat

import ITensors: ITensors, ITensor, Index, dim, uniqueinds, commoninds, uniqueind
import ITensors: ITensors, ITensor, Index, dim, uniqueinds, commoninds, uniqueind, inds
import ITensorMPS
import ITensors: Algorithm, @Algorithm_str
import LinearAlgebra
Expand All @@ -27,12 +27,15 @@ import LinearAlgebra
export TensorTrain
export MPS, MPO, AbstractMPS # Temporary aliases for migration
export contract
#export fit # Fit function for summing multiple tensor trains with coefficients
export fit # Fit function for summing multiple tensor trains with coefficients
export truncate, truncate!
export maxlinkdim, siteinds
export linkinds, linkind, findsite, findsites, isortho, orthocenter # Functions for compatibility
export default_maxdim, default_cutoff, default_nsweeps
export lognorm # Log norm function
export random_mps, random_mpo # Random tensor train generation
export product # Official API name (match ITensorMPS)
export apply # Backwards-compatible alias


include("defaults.jl")
Expand All @@ -45,6 +48,5 @@ const MPO = TensorTrain
const AbstractMPS = TensorTrain

include("contraction.jl")
include("itensormps_compat.jl") # Compatibility functions for ITensorMPS API

end
127 changes: 121 additions & 6 deletions src/contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,49 @@ result = contract(M1, M2; alg=Algorithm"densitymatrix"(), maxdim=100)
```
"""
function contract(M1::TensorTrain, M2::TensorTrain; alg=Algorithm"fit"(), cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), nsweeps::Int=default_nsweeps(), kwargs...)::TensorTrain
M1_ = ITensorMPS.MPO(M1)
M2_ = ITensorMPS.MPO(M2)
# Detect MPS-like vs MPO-like by counting physical indices per site
is_mps1 = begin
sites_per_tensor = siteinds(M1)
length(M1) > 0 && all(length(s) == 1 for s in sites_per_tensor)
end
is_mps2 = begin
sites_per_tensor = siteinds(M2)
length(M2) > 0 && all(length(s) == 1 for s in sites_per_tensor)
end

# Convert to ITensorMPS types based on detected type
M1_ = is_mps1 ? ITensorMPS.MPS(M1) : ITensorMPS.MPO(M1)
M2_ = is_mps2 ? ITensorMPS.MPS(M2) : ITensorMPS.MPO(M2)

alg = Algorithm(alg)

# Handle MPO * MPS case
if !is_mps1 && is_mps2
# MPO * MPS: use ITensorMPS.contract
# Only pass nsweeps for fit algorithm, otherwise remove it from kwargs
if alg == Algorithm"fit"()
result = ITensorMPS.contract(M1_, M2_; alg=alg, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)
else
# Remove nsweeps from kwargs for non-fit algorithms
kwargs_dict = Dict(kwargs)
delete!(kwargs_dict, :nsweeps)
result = ITensorMPS.contract(M1_, M2_; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs_dict...)
end
return TensorTrain(result)
elseif is_mps1 && !is_mps2
# MPS * MPO: use ITensorMPS.contract (commutative)
# Only pass nsweeps for fit algorithm, otherwise remove it from kwargs
if alg == Algorithm"fit"()
result = ITensorMPS.contract(M2_, M1_; alg=alg, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)
else
# Remove nsweeps from kwargs for non-fit algorithms
kwargs_dict = Dict(kwargs)
delete!(kwargs_dict, :nsweeps)
result = ITensorMPS.contract(M2_, M1_; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs_dict...)
end
return TensorTrain(result)
else
# MPO * MPO: use T4AITensorCompat algorithms
if alg == Algorithm"densitymatrix"()
return TensorTrain(ContractionImpl.contract_densitymatrix(M1_, M2_; cutoff, maxdim, kwargs...))
elseif alg == Algorithm"fit"()
Expand All @@ -81,11 +121,12 @@ function contract(M1::TensorTrain, M2::TensorTrain; alg=Algorithm"fit"(), cutoff
return TensorTrain(ITensorMPS.contract(M1_, M2_; alg=Algorithm"naive"(), cutoff, maxdim, kwargs...))
else
error("Unknown algorithm: $alg")
end
end
end

"""
fit(input_states::AbstractVector{TensorTrain}, init::TensorTrain; coeffs::AbstractVector{<:Number}=ones(Int, length(input_states)), kwargs...)
fit(input_states::AbstractVector{TensorTrain}, init::TensorTrain; coeffs::AbstractVector{<:Number}=ones(Int, length(input_states)), cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), nsweeps::Int=default_nsweeps(), kwargs...)

Fit a linear combination of multiple TensorTrain objects to approximate their weighted sum.

Expand All @@ -112,9 +153,10 @@ bond dimensions while maintaining numerical accuracy.
# Fit a weighted sum of three tensor trains
result = fit([tt1, tt2, tt3], init_tt; coeffs=[1.0, 2.0, 0.5])
```
"""
#==

# Note
FIXME (HS): I observed this function is sometime less accurate than direct sum of the input states.
"""
function fit(
input_states::AbstractVector{TensorTrain},
init::TensorTrain;
Expand All @@ -124,6 +166,7 @@ function fit(
nsweeps::Int = default_nsweeps(),
kwargs...,
)::TensorTrain
println(stderr, "⚠ Warning: The `fit` function may produce less accurate results than direct sum of the input states. Consider using direct sum (`+`) if accuracy is critical.")
# Convert TensorTrain objects to ITensorMPS.MPS
mps_inputs = [ITensorMPS.MPS(tt) for tt in input_states]
mps_init = ITensorMPS.MPS(init)
Expand All @@ -134,4 +177,76 @@ function fit(
# Convert back to TensorTrain
return TensorTrain(mps_result, init.llim, init.rlim)
end
==#

"""
product(A::TensorTrain, Ψ::TensorTrain; alg=Algorithm"fit"(), cutoff=default_cutoff(), maxdim=default_maxdim(), nsweeps=default_nsweeps(), kwargs...)

Multiply an MPO with an MPS or MPO (official API name).

This function multiplies an MPO `A` by a tensor train `Ψ` (MPS or MPO),
using `contract` internally and adjusting prime levels for compatibility.

# Arguments
- `A::TensorTrain`: The MPO to apply
- `Ψ::TensorTrain`: The MPS or MPO to multiply

# Keywords
- `alg`: Algorithm specification (String or Algorithm type). Defaults to "fit".
- `cutoff::Real`: Truncation cutoff. Defaults to `default_cutoff()`.
- `maxdim::Int`: Maximum bond dimension. Defaults to `default_maxdim()`.
- `nsweeps::Int`: Number of sweeps for variational algorithms. Defaults to `default_nsweeps()`.
- `kwargs...`: Additional keyword arguments passed to contract
"""
function product(A::TensorTrain, Ψ::TensorTrain; alg=Algorithm"fit"(), cutoff=default_cutoff(), maxdim=default_maxdim(), nsweeps=default_nsweeps(), kwargs...)
if :algorithm ∈ keys(kwargs)
error("keyword argument :algorithm is not allowed")
end

# Convert alg to Algorithm type (accepts both String and Algorithm)
alg_ = alg isa Algorithm ? alg : Algorithm(alg)

# Warn if cutoff is too small for densitymatrix algorithm
if alg_ == Algorithm("densitymatrix") && cutoff <= 1e-10
@warn "cutoff is too small for densitymatrix algorithm. Use fit algorithm instead."
end

# Detect MPS-like vs MPO-like by counting physical indices per site:
# MPS tensors have 1 physical index per site, MPO tensors have 2 per site.
# This is robust to boundary tensors having fewer link indices.
is_mps_like = begin
sites_per_tensor = siteinds(Ψ)
length(Ψ) > 0 && all(length(s) == 1 for s in sites_per_tensor)
end

if is_mps_like
# Apply MPO to MPS: contract(A, ψ) then replaceprime(..., 1 => 0)
# Use T4AITensorCompat.contract for MPO * MPS
result_tt = contract(A, Ψ; alg=alg_, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)
# Adjust prime levels: replaceprime(..., 1 => 0) to get unprimed result
# Use ITensorMPS.replaceprime by converting to MPS
result_mps = ITensorMPS.MPS(result_tt)
result_mps = ITensorMPS.replaceprime(result_mps, 1 => 0)
return TensorTrain(result_mps)
else
# Apply MPO to MPO: contract(A', B) then replaceprime(..., 2 => 1)
# Use T4AITensorCompat.contract for MPO * MPO (with A' to contract over one set of indices)
A_primed = ITensors.prime(A)
result_tt = contract(A_primed, Ψ; alg=alg_, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)
# Adjust prime levels: replaceprime(..., 2 => 1) to get pairs of primed/unprimed indices
# Use ITensorMPS.replaceprime by converting to MPO
result_mpo = ITensorMPS.MPO(result_tt)
result_mpo = ITensorMPS.replaceprime(result_mpo, 2 => 1)
return TensorTrain(result_mpo)
end
end

"""
apply(A::TensorTrain, Ψ::TensorTrain; kwargs...)

Backwards-compatible alias for [`product`](@ref).

This function multiplies an MPO `A` by a tensor train `Ψ` (MPS or MPO) and
forwards all keyword arguments to [`product`](@ref).
See [`product`](@ref) for the full list of supported keywords and behavior.
"""
apply(A::TensorTrain, Ψ::TensorTrain; kwargs...) = product(A, Ψ; kwargs...)
21 changes: 21 additions & 0 deletions src/ext/T4AITensorCompatChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Extension module to resolve ambiguities with ChainRulesCore and InitialValues
# This module is automatically loaded when ChainRulesCore or InitialValues are available

module T4AITensorCompatChainRulesCoreExt

using ChainRulesCore: NoTangent, AbstractThunk, NotImplemented, ZeroTangent
using InitialValues: NonspecificInitialValue, SpecificInitialValue
using ..T4AITensorCompat: TensorTrain

# Resolve ambiguities with ChainRulesCore types
Base.:+(stt::TensorTrain, ::NoTangent) = error("Cannot add ChainRulesCore.NoTangent to TensorTrain")
Base.:+(stt::TensorTrain, ::AbstractThunk) = error("Cannot add ChainRulesCore.AbstractThunk to TensorTrain")
Base.:+(stt::TensorTrain, ::NotImplemented) = error("Cannot add ChainRulesCore.NotImplemented to TensorTrain")
Base.:+(stt::TensorTrain, ::ZeroTangent) = error("Cannot add ChainRulesCore.ZeroTangent to TensorTrain")

# Resolve ambiguities with InitialValues types
Base.:+(stt::TensorTrain, ::Union{NonspecificInitialValue, SpecificInitialValue{typeof(+)}}) =
error("Cannot add InitialValues to TensorTrain")

end

Loading
Loading