diff --git a/.gitignore b/.gitignore index aabab04..c70179b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /Manifest.toml /docs/build/ /docs/Manifest.toml +*.log diff --git a/Project.toml b/Project.toml index c1b08b4..5c67dbf 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ license = "Apache-2.0" ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] ITensors = "0.9" @@ -16,6 +17,8 @@ julia = "1" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +InitialValues = "22cec73e-a143-4acd-8c0c-585de75c995d" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" diff --git a/docs/src/index.md b/docs/src/index.md index 3dbae97..cbb24e5 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -25,6 +25,8 @@ TensorTrain(::ITensorMPS.MPO, ::Int, ::Int) ```@docs contract +product +apply truncate truncate! maxlinkdim @@ -35,9 +37,17 @@ findsite findsites isortho orthocenter +fit lognorm ``` +## Random Generators + +```@docs +random_mps +random_mpo +``` + ## Default Parameters ```@docs diff --git a/ext/T4AITensorCompatChainRulesCoreExt.jl b/ext/T4AITensorCompatChainRulesCoreExt.jl new file mode 100644 index 0000000..17c7169 --- /dev/null +++ b/ext/T4AITensorCompatChainRulesCoreExt.jl @@ -0,0 +1,21 @@ +# Extension module to resolve ambiguities with ChainRulesCore and InitialValues +# This module is automatically loaded when ChainRulesCore or InitialValues are available + +module T4AITensorCompatChainRulesCoreExt + +using ChainRulesCore: NoTangent, AbstractThunk, NotImplemented, ZeroTangent +using InitialValues: NonspecificInitialValue, SpecificInitialValue +using ..T4AITensorCompat: TensorTrain + +# Resolve ambiguities with ChainRulesCore types +Base.:+(stt::TensorTrain, ::NoTangent) = error("Cannot add ChainRulesCore.NoTangent to TensorTrain") +Base.:+(stt::TensorTrain, ::AbstractThunk) = error("Cannot add ChainRulesCore.AbstractThunk to TensorTrain") +Base.:+(stt::TensorTrain, ::NotImplemented) = error("Cannot add ChainRulesCore.NotImplemented to TensorTrain") +Base.:+(stt::TensorTrain, ::ZeroTangent) = error("Cannot add ChainRulesCore.ZeroTangent to TensorTrain") + +# Resolve ambiguities with InitialValues types +Base.:+(stt::TensorTrain, ::Union{NonspecificInitialValue, SpecificInitialValue{typeof(+)}}) = + error("Cannot add InitialValues to TensorTrain") + +end + diff --git a/src/T4AITensorCompat.jl b/src/T4AITensorCompat.jl index bbd37c6..6a8281c 100644 --- a/src/T4AITensorCompat.jl +++ b/src/T4AITensorCompat.jl @@ -18,7 +18,7 @@ limitations under the License. module T4AITensorCompat -import ITensors: ITensors, ITensor, Index, dim, uniqueinds, commoninds, uniqueind +import ITensors: ITensors, ITensor, Index, dim, uniqueinds, commoninds, uniqueind, inds import ITensorMPS import ITensors: Algorithm, @Algorithm_str import LinearAlgebra @@ -27,12 +27,15 @@ import LinearAlgebra export TensorTrain export MPS, MPO, AbstractMPS # Temporary aliases for migration export contract -#export fit # Fit function for summing multiple tensor trains with coefficients +export fit # Fit function for summing multiple tensor trains with coefficients export truncate, truncate! 
export maxlinkdim, siteinds export linkinds, linkind, findsite, findsites, isortho, orthocenter # Functions for compatibility export default_maxdim, default_cutoff, default_nsweeps export lognorm # Log norm function +export random_mps, random_mpo # Random tensor train generation +export product # Official API name (match ITensorMPS) +export apply # Backwards-compatible alias include("defaults.jl") @@ -45,6 +48,5 @@ const MPO = TensorTrain const AbstractMPS = TensorTrain include("contraction.jl") -include("itensormps_compat.jl") # Compatibility functions for ITensorMPS API end diff --git a/src/contraction.jl b/src/contraction.jl index 7c86255..e20e3a7 100644 --- a/src/contraction.jl +++ b/src/contraction.jl @@ -68,9 +68,49 @@ result = contract(M1, M2; alg=Algorithm"densitymatrix"(), maxdim=100) ``` """ function contract(M1::TensorTrain, M2::TensorTrain; alg=Algorithm"fit"(), cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), nsweeps::Int=default_nsweeps(), kwargs...)::TensorTrain - M1_ = ITensorMPS.MPO(M1) - M2_ = ITensorMPS.MPO(M2) + # Detect MPS-like vs MPO-like by counting physical indices per site + is_mps1 = begin + sites_per_tensor = siteinds(M1) + length(M1) > 0 && all(length(s) == 1 for s in sites_per_tensor) + end + is_mps2 = begin + sites_per_tensor = siteinds(M2) + length(M2) > 0 && all(length(s) == 1 for s in sites_per_tensor) + end + + # Convert to ITensorMPS types based on detected type + M1_ = is_mps1 ? ITensorMPS.MPS(M1) : ITensorMPS.MPO(M1) + M2_ = is_mps2 ? ITensorMPS.MPS(M2) : ITensorMPS.MPO(M2) + alg = Algorithm(alg) + + # Handle MPO * MPS case + if !is_mps1 && is_mps2 + # MPO * MPS: use ITensorMPS.contract + # Only pass nsweeps for fit algorithm, otherwise remove it from kwargs + if alg == Algorithm"fit"() + result = ITensorMPS.contract(M1_, M2_; alg=alg, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...) + else + # Remove nsweeps from kwargs for non-fit algorithms + kwargs_dict = Dict(kwargs) + delete!(kwargs_dict, :nsweeps) + result = ITensorMPS.contract(M1_, M2_; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs_dict...) + end + return TensorTrain(result) + elseif is_mps1 && !is_mps2 + # MPS * MPO: use ITensorMPS.contract (commutative) + # Only pass nsweeps for fit algorithm, otherwise remove it from kwargs + if alg == Algorithm"fit"() + result = ITensorMPS.contract(M2_, M1_; alg=alg, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...) + else + # Remove nsweeps from kwargs for non-fit algorithms + kwargs_dict = Dict(kwargs) + delete!(kwargs_dict, :nsweeps) + result = ITensorMPS.contract(M2_, M1_; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs_dict...) + end + return TensorTrain(result) + else + # MPO * MPO: use T4AITensorCompat algorithms if alg == Algorithm"densitymatrix"() return TensorTrain(ContractionImpl.contract_densitymatrix(M1_, M2_; cutoff, maxdim, kwargs...)) elseif alg == Algorithm"fit"() @@ -81,11 +121,12 @@ function contract(M1::TensorTrain, M2::TensorTrain; alg=Algorithm"fit"(), cutoff return TensorTrain(ITensorMPS.contract(M1_, M2_; alg=Algorithm"naive"(), cutoff, maxdim, kwargs...)) else error("Unknown algorithm: $alg") + end end end """ - fit(input_states::AbstractVector{TensorTrain}, init::TensorTrain; coeffs::AbstractVector{<:Number}=ones(Int, length(input_states)), kwargs...) + fit(input_states::AbstractVector{TensorTrain}, init::TensorTrain; coeffs::AbstractVector{<:Number}=ones(Int, length(input_states)), cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), nsweeps::Int=default_nsweeps(), kwargs...) 
Fit a linear combination of multiple TensorTrain objects to approximate their weighted sum. @@ -112,9 +153,10 @@ bond dimensions while maintaining numerical accuracy. # Fit a weighted sum of three tensor trains result = fit([tt1, tt2, tt3], init_tt; coeffs=[1.0, 2.0, 0.5]) ``` -""" -#== + +# Note FIXME (HS): I observed this function is sometime less accurate than direct sum of the input states. +""" function fit( input_states::AbstractVector{TensorTrain}, init::TensorTrain; @@ -124,6 +166,7 @@ function fit( nsweeps::Int = default_nsweeps(), kwargs..., )::TensorTrain + println(stderr, "⚠ Warning: The `fit` function may produce less accurate results than direct sum of the input states. Consider using direct sum (`+`) if accuracy is critical.") # Convert TensorTrain objects to ITensorMPS.MPS mps_inputs = [ITensorMPS.MPS(tt) for tt in input_states] mps_init = ITensorMPS.MPS(init) @@ -134,4 +177,76 @@ function fit( # Convert back to TensorTrain return TensorTrain(mps_result, init.llim, init.rlim) end -==# \ No newline at end of file + +""" + product(A::TensorTrain, Ψ::TensorTrain; alg=Algorithm"fit"(), cutoff=default_cutoff(), maxdim=default_maxdim(), nsweeps=default_nsweeps(), kwargs...) + +Multiply an MPO with an MPS or MPO (official API name). + +This function multiplies an MPO `A` by a tensor train `Ψ` (MPS or MPO), +using `contract` internally and adjusting prime levels for compatibility. + +# Arguments +- `A::TensorTrain`: The MPO to apply +- `Ψ::TensorTrain`: The MPS or MPO to multiply + +# Keywords +- `alg`: Algorithm specification (String or Algorithm type). Defaults to "fit". +- `cutoff::Real`: Truncation cutoff. Defaults to `default_cutoff()`. +- `maxdim::Int`: Maximum bond dimension. Defaults to `default_maxdim()`. +- `nsweeps::Int`: Number of sweeps for variational algorithms. Defaults to `default_nsweeps()`. +- `kwargs...`: Additional keyword arguments passed to contract +""" +function product(A::TensorTrain, Ψ::TensorTrain; alg=Algorithm"fit"(), cutoff=default_cutoff(), maxdim=default_maxdim(), nsweeps=default_nsweeps(), kwargs...) + if :algorithm ∈ keys(kwargs) + error("keyword argument :algorithm is not allowed") + end + + # Convert alg to Algorithm type (accepts both String and Algorithm) + alg_ = alg isa Algorithm ? alg : Algorithm(alg) + + # Warn if cutoff is too small for densitymatrix algorithm + if alg_ == Algorithm("densitymatrix") && cutoff <= 1e-10 + @warn "cutoff is too small for densitymatrix algorithm. Use fit algorithm instead." + end + + # Detect MPS-like vs MPO-like by counting physical indices per site: + # MPS tensors have 1 physical index per site, MPO tensors have 2 per site. + # This is robust to boundary tensors having fewer link indices. + is_mps_like = begin + sites_per_tensor = siteinds(Ψ) + length(Ψ) > 0 && all(length(s) == 1 for s in sites_per_tensor) + end + + if is_mps_like + # Apply MPO to MPS: contract(A, ψ) then replaceprime(..., 1 => 0) + # Use T4AITensorCompat.contract for MPO * MPS + result_tt = contract(A, Ψ; alg=alg_, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...) 
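+        # The contraction above consumes the unprimed site indices shared by `A`
+        # and `Ψ`, so at this point the result carries the primed output indices
+        # of the MPO; mapping prime level 1 back to 0 below restores site indices
+        # matching the input MPS.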
+ # Adjust prime levels: replaceprime(..., 1 => 0) to get unprimed result + # Use ITensorMPS.replaceprime by converting to MPS + result_mps = ITensorMPS.MPS(result_tt) + result_mps = ITensorMPS.replaceprime(result_mps, 1 => 0) + return TensorTrain(result_mps) + else + # Apply MPO to MPO: contract(A', B) then replaceprime(..., 2 => 1) + # Use T4AITensorCompat.contract for MPO * MPO (with A' to contract over one set of indices) + A_primed = ITensors.prime(A) + result_tt = contract(A_primed, Ψ; alg=alg_, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...) + # Adjust prime levels: replaceprime(..., 2 => 1) to get pairs of primed/unprimed indices + # Use ITensorMPS.replaceprime by converting to MPO + result_mpo = ITensorMPS.MPO(result_tt) + result_mpo = ITensorMPS.replaceprime(result_mpo, 2 => 1) + return TensorTrain(result_mpo) + end +end + +""" + apply(A::TensorTrain, Ψ::TensorTrain; kwargs...) + +Backwards-compatible alias for [`product`](@ref). + +This function multiplies an MPO `A` by a tensor train `Ψ` (MPS or MPO) and +forwards all keyword arguments to [`product`](@ref). +See [`product`](@ref) for the full list of supported keywords and behavior. +""" +apply(A::TensorTrain, Ψ::TensorTrain; kwargs...) = product(A, Ψ; kwargs...) \ No newline at end of file diff --git a/src/ext/T4AITensorCompatChainRulesCoreExt.jl b/src/ext/T4AITensorCompatChainRulesCoreExt.jl new file mode 100644 index 0000000..17c7169 --- /dev/null +++ b/src/ext/T4AITensorCompatChainRulesCoreExt.jl @@ -0,0 +1,21 @@ +# Extension module to resolve ambiguities with ChainRulesCore and InitialValues +# This module is automatically loaded when ChainRulesCore or InitialValues are available + +module T4AITensorCompatChainRulesCoreExt + +using ChainRulesCore: NoTangent, AbstractThunk, NotImplemented, ZeroTangent +using InitialValues: NonspecificInitialValue, SpecificInitialValue +using ..T4AITensorCompat: TensorTrain + +# Resolve ambiguities with ChainRulesCore types +Base.:+(stt::TensorTrain, ::NoTangent) = error("Cannot add ChainRulesCore.NoTangent to TensorTrain") +Base.:+(stt::TensorTrain, ::AbstractThunk) = error("Cannot add ChainRulesCore.AbstractThunk to TensorTrain") +Base.:+(stt::TensorTrain, ::NotImplemented) = error("Cannot add ChainRulesCore.NotImplemented to TensorTrain") +Base.:+(stt::TensorTrain, ::ZeroTangent) = error("Cannot add ChainRulesCore.ZeroTangent to TensorTrain") + +# Resolve ambiguities with InitialValues types +Base.:+(stt::TensorTrain, ::Union{NonspecificInitialValue, SpecificInitialValue{typeof(+)}}) = + error("Cannot add InitialValues to TensorTrain") + +end + diff --git a/src/itensormps_compat.jl b/src/itensormps_compat.jl deleted file mode 100644 index 7a6af95..0000000 --- a/src/itensormps_compat.jl +++ /dev/null @@ -1,167 +0,0 @@ -#=== -itensormps_compat.jl - Compatibility functions for ITensorMPS API - -This file provides compatibility functions that match ITensorMPS API -for use with TensorTrain. These are temporary functions for migration purposes. - -Copyright (c) 2025 Hiroshi Shinaoka and contributors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -====# - -""" - linkinds(tt::TensorTrain) - -Extract the link (bond) indices from a TensorTrain. - -This function returns a vector of link indices connecting adjacent tensors -in the tensor train. For a TensorTrain of length N, it returns N-1 link indices. - -# Arguments -- `tt::TensorTrain`: The tensor train to extract link indices from - -# Returns -- `Vector{Index}`: Vector of link indices connecting adjacent tensors -""" -function linkinds(tt::TensorTrain) - N = length(tt) - if N <= 1 - return Index[] - end - links = Index[] - for n in 1:(N - 1) - # Link index is the common index between tensor n and n+1 - common = commoninds(tt[n], tt[n + 1]) - if length(common) != 1 - error("Expected exactly one common index between tensors $n and $(n+1), got $(length(common))") - end - push!(links, only(common)) - end - return links -end - -""" - linkind(tt::TensorTrain, p::Int) - -Get the link index at position p in a TensorTrain. - -Position p refers to the link between tensor p and p+1. -Valid positions are 1 to length(tt)-1. - -# Arguments -- `tt::TensorTrain`: The tensor train -- `p::Int`: Position of the link (1-indexed, between tensor p and p+1) - -# Returns -- `Index`: The link index at position p -""" -function linkind(tt::TensorTrain, p::Int) - links = linkinds(tt) - if p < 1 || p > length(links) - error("Link position $p out of range. Valid range: 1 to $(length(links))") - end - return links[p] -end - -""" - findsite(tt::TensorTrain, site::Index) - -Find the position of a site index in a TensorTrain. - -This function searches for the site index in the tensor train and returns -the position (1-indexed) where it is found. Returns `nothing` if not found. - -# Arguments -- `tt::TensorTrain`: The tensor train to search -- `site::Index`: The site index to find - -# Returns -- `Union{Int, Nothing}`: Position of the site index, or `nothing` if not found -""" -function findsite(tt::TensorTrain, site::Index) - sites = siteinds(tt) - for (pos, site_vec) in enumerate(sites) - if site in site_vec - return pos - end - end - return nothing -end - -""" - findsites(tt::TensorTrain, site::Index) - -Find all positions of a site index in a TensorTrain. - -This function searches for the site index in the tensor train and returns -all positions (1-indexed) where it is found. - -# Arguments -- `tt::TensorTrain`: The tensor train to search -- `site::Index`: The site index to find - -# Returns -- `Vector{Int}`: Vector of positions where the site index is found -""" -function findsites(tt::TensorTrain, site::Index) - sites = siteinds(tt) - positions = Int[] - for (pos, site_vec) in enumerate(sites) - if site in site_vec - push!(positions, pos) - end - end - return positions -end - -""" - isortho(tt::TensorTrain) - -Check if a TensorTrain is orthogonal (canonical form). - -This function checks whether the tensor train is in orthogonal/canonical form -by delegating to ITensorMPS.isortho after converting to MPS. - -# Arguments -- `tt::TensorTrain`: The tensor train to check - -# Returns -- `Bool`: `true` if the tensor train is orthogonal, `false` otherwise -""" -function isortho(tt::TensorTrain) - mps = ITensorMPS.MPS(tt) - return ITensorMPS.isortho(mps) -end - -""" - orthocenter(tt::TensorTrain) - -Get the orthogonality center position of a TensorTrain. - -This function returns the position of the orthogonality center in the tensor train -by delegating to ITensorMPS.orthocenter after converting to MPS. 
- -# Arguments -- `tt::TensorTrain`: The tensor train - -# Returns -- `Int`: Position of the orthogonality center (1-indexed) -""" -function orthocenter(tt::TensorTrain) - mps = ITensorMPS.MPS(tt) - return ITensorMPS.orthocenter(mps) -end - -# Note: MPS and MPO are type aliases for TensorTrain, so MPS([...]) and MPO([...]) -# will automatically call TensorTrain([...]) constructor. No explicit constructors needed. - diff --git a/src/tensortrain.jl b/src/tensortrain.jl index c7311f5..9638c74 100644 --- a/src/tensortrain.jl +++ b/src/tensortrain.jl @@ -39,6 +39,9 @@ mutable struct TensorTrain data::Vector{ITensor} llim::Int rlim::Int + + # Internal constructor to prevent ambiguous matches + TensorTrain(data::Vector{ITensor}, llim::Int, rlim::Int) = new(data, llim, rlim) end """ @@ -58,6 +61,245 @@ function TensorTrain(data::Vector{ITensor}) return TensorTrain(data, 0, length(data) + 1) end +# Size-based constructor to allocate an empty tensor train of length `n`. +# Elements are uninitialized ITensors and should be assigned by the caller. +function TensorTrain(n::Int) + return TensorTrain(Vector{ITensor}(undef, n), 0, n + 1) +end + +""" + TensorTrain(sites::Vector{<:Index}) + +Construct a TensorTrain (MPS) from a vector of site indices. + +This creates an MPS with empty ITensors initialized with the specified site indices. + +# Arguments +- `sites::Vector{<:Index}`: Vector of site indices + +# Returns +- `TensorTrain`: A new TensorTrain (MPS) object +""" +function TensorTrain(sites::Vector{<:Index}) + # Construct an MPS with default link dimensions + mps = ITensorMPS.MPS(Float64, sites) + return TensorTrain(mps) +end + +""" + TensorTrain(::Type{T}, sites::Vector{<:Index}; linkdims=1) where {T<:Number} + +Construct a TensorTrain (MPS) from a vector of site indices with specified element type. + +This creates an MPS with empty ITensors of type `T` initialized with the specified site indices. + +# Arguments +- `T::Type{<:Number}`: Element type (e.g., Float64, ComplexF64) +- `sites::Vector{<:Index}`: Vector of site indices +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A new TensorTrain (MPS) object +""" +function TensorTrain(::Type{T}, sites::Vector{<:Index}; linkdims::Union{Integer,Vector{<:Integer}}=1) where {T<:Number} + mps = ITensorMPS.MPS(T, sites; linkdims=linkdims) + return TensorTrain(mps) +end + +""" + TensorTrain(::Type{T}, sites::Vector{Vector{<:Index}}, linkdims::Union{Integer,Vector{<:Integer}}) where {T<:Number} + +Construct a TensorTrain (MPO) from a vector of site index pairs with specified element type and link dimensions. + +This creates an MPO by manually constructing ITensor objects for each site. Each site is represented by a pair of indices +(typically the upper and lower indices of the MPO). 
+ +# Arguments +- `T::Type{<:Number}`: Element type (e.g., Float64, ComplexF64) +- `sites::Vector{Vector{<:Index}}`: Vector of site index pairs, where each element is a vector of indices for that site +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) between sites + +# Returns +- `TensorTrain`: A new TensorTrain (MPO) object + +# Examples +```julia +sites = [[Index(2, "Site,n=\$n"), Index(2, "Site,n=\$n")] for n=1:5] +mpo = TensorTrain(ComplexF64, sites, 2) +``` +""" +function TensorTrain(::Type{T}, sites::AbstractVector{<:AbstractVector{<:Index}}, linkdims::Union{Integer,Vector{<:Integer}}) where {T<:Number} + N = length(sites) + N == 0 && return TensorTrain(Vector{ITensor}()) + + # Normalize linkdims to a vector + if linkdims isa Integer + _linkdims = fill(linkdims, N - 1) + else + _linkdims = linkdims + end + + length(_linkdims) == N - 1 || error("Length mismatch: linkdims ($(length(_linkdims))) must have length $(N - 1)") + + # Create internal link indices (no boundary links) + links = N > 1 ? [Index(_linkdims[n], "Link,l=$n") for n in 1:(N - 1)] : Index[] + + # Create ITensors for each site + tensors = Vector{ITensor}(undef, N) + for n in 1:N + site_inds = sites[n] + if length(site_inds) != 2 + error("Each site must have exactly 2 indices (upper and lower), but site $n has $(length(site_inds))") + end + + # For MPO: sites[n] = [lower_index, upper_index] + # Lower index is unprimed, upper index is primed + lower_ind = site_inds[1] # unprimed + upper_ind = ITensors.prime(site_inds[2]) # primed + + # MPO structure: + # - First site: (upper_site, lower_site, right_link) + # - Last site: (left_link, upper_site, lower_site) + # - Middle: (left_link, upper_site, lower_site, right_link) + if n == 1 && n == N + inds_tuple = (upper_ind, lower_ind) + elseif n == 1 + inds_tuple = (upper_ind, lower_ind, links[n]) + elseif n == N + inds_tuple = (links[n - 1], upper_ind, lower_ind) + else + inds_tuple = (links[n - 1], upper_ind, lower_ind, links[n]) + end + + # Create zero ITensor with appropriate dimensions + dims = map(ITensors.dim, inds_tuple) + data = zeros(T, dims...) + tensors[n] = ITensor(data, inds_tuple...) + end + + return TensorTrain(tensors) +end + +""" + TensorTrain(A::ITensor, sites; kwargs...) + +Construct a TensorTrain from an ITensor by decomposing it according to site indices. + +This function creates an MPS or MPO by decomposing the ITensor `A` site by site +according to the site indices `sites`. The `sites` can be either `Vector{Index}` +(for MPS) or `Vector{Vector{Index}}` (for MPO). + +# Arguments +- `A::ITensor`: The ITensor to decompose +- `sites`: Site indices - either `Vector{Index}` (for MPS) or `Vector{Vector{Index}}` (for MPO) + +# Keyword Arguments +- `leftinds=nothing`: Optional left dangling indices +- `orthocenter::Integer=length(sites)`: Desired orthogonality center +- `tags`: Tags for link indices +- `cutoff`: Truncation error at each link +- `maxdim`: Maximum link dimension +- `kwargs...`: Additional keyword arguments passed to the decomposition + +# Returns +- `TensorTrain`: A new TensorTrain object (MPS or MPO depending on `sites`) + +# Examples +```julia +sites = [Index(2, "Site,n=\$n") for n=1:5] +A = randomITensor(sites...) +tt = TensorTrain(A, sites) +``` +""" +function TensorTrain(A::ITensor, sites; kwargs...) + # Detect if sites is Vector{Index} (MPS) or Vector{Vector{Index}} (MPO) + if length(sites) > 0 && sites[1] isa Index + # MPS case: sites is Vector{Index} + mps = ITensorMPS.MPS(A, sites; kwargs...) 
+ return TensorTrain(mps) + else + # MPO case: sites is Vector{Vector{Index}} or similar + mpo = ITensorMPS.MPO(A, sites; kwargs...) + return TensorTrain(mpo) + end +end + +""" + TensorTrain(A::AbstractArray, sites; kwargs...) + +Construct a TensorTrain from an AbstractArray by converting it to an ITensor first. + +# Arguments +- `A::AbstractArray`: The array to convert +- `sites`: Site indices - either `Vector{Index}` (for MPS) or `Vector{Vector{Index}}` (for MPO) +- `kwargs...`: Keyword arguments passed to the decomposition + +# Returns +- `TensorTrain`: A new TensorTrain object +""" +function TensorTrain(A::AbstractArray, sites; kwargs...) + # Convert array to ITensor + if length(sites) > 0 && sites[1] isa Index + # MPS case: sites is Vector{Index} + A_itensor = ITensor(A, sites...) + mps = ITensorMPS.MPS(A_itensor, sites; kwargs...) + return TensorTrain(mps) + else + # MPO case: sites is Vector{Vector{Index}} + # Flatten sites for ITensor construction + sites_flat = collect(Iterators.flatten(sites)) + A_itensor = ITensor(A, sites_flat...) + mpo = ITensorMPS.MPO(A_itensor, sites; kwargs...) + return TensorTrain(mpo) + end +end + +""" + TensorTrain(tt_data; sites, kwargs...) + +Construct a TensorTrain from tensor train data (like QuanticsTCI/TensorCrossInterpolation) with specified sites. + +This constructor is used for converting from other tensor train formats (e.g., QuanticsTCI, TensorCrossInterpolation) +to TensorTrain by specifying the site indices. + +# Arguments +- `tt_data`: Tensor train data (e.g., from QuanticsTCI or TensorCrossInterpolation) +- `sites`: Site indices - either `Vector{Index}` (for MPS) or `Vector{Vector{Index}}` (for MPO) + +# Keyword Arguments +- `kwargs...`: Additional keyword arguments passed to ITensorMPS constructors + +# Returns +- `TensorTrain`: A new TensorTrain object +""" +function TensorTrain(tt_data; sites, kwargs...) + if length(sites) > 0 && sites[1] isa Index + # MPS case: sites is Vector{Index} + mps = ITensorMPS.MPS(tt_data, sites; kwargs...) + return TensorTrain(mps) + else + # MPO case: sites is Vector{Vector{Index}} + # Check if tt_data is TensorCrossInterpolation.TensorTrain + # Try to use TensorCrossInterpolation conversion if available + tt_type = typeof(tt_data) + if hasmethod(ITensorMPS.MPO, (typeof(tt_data),); kwargs...) + # Try ITensorMPS.MPO with sites keyword argument (for TensorCrossInterpolation) + try + mpo = ITensorMPS.MPO(tt_data; sites=sites, kwargs...) + return TensorTrain(mpo) + catch + # Fallback to generic MPO constructor + mpo = ITensorMPS.MPO(tt_data, sites; kwargs...) + return TensorTrain(mpo) + end + else + # Try generic MPO constructor + mpo = ITensorMPS.MPO(tt_data, sites; kwargs...) + return TensorTrain(mpo) + end + end +end + # Iterator implementation Base.iterate(stt::TensorTrain) = iterate(stt.data) Base.iterate(stt::TensorTrain, state) = iterate(stt.data, state) @@ -178,17 +420,11 @@ function Base.:+(stt1::TensorTrain, stt2::TensorTrain) end """ -Add multiple TensorTrain objects using ITensors.Algorithm("directsum") - -This function computes the sum of multiple tensor trains by: -1. Converting all TensorTrain objects to ITensorMPS.MPS -2. Computing the sum using ITensors.Algorithm("directsum") for high precision -3. Converting the result back to TensorTrain +Add multiple TensorTrain objects using ITensors.Algorithm("directsum"). -The result preserves the tensor structure while combining all bond dimensions. -Uses Algorithm("directsum") instead of default + operator for better numerical precision. 
+The sum is computed with Algorithm("directsum") for high precision. """ -function Base.:+(stt1::TensorTrain, stt2::TensorTrain, stts...) +function Base.:+(stt1::TensorTrain, stt2::TensorTrain, stts::Vararg{TensorTrain}) # Check that all tensor trains have the same length lengths = [length(stt.data) for stt in [stt1, stt2, stts...]] if !all(l -> l == lengths[1], lengths) @@ -395,7 +631,7 @@ This function computes the difference of multiple tensor trains by: The result preserves the tensor structure while combining all bond dimensions. Uses Algorithm("directsum") for optimal numerical precision. """ -function Base.:-(stt1::TensorTrain, stt2::TensorTrain, stts...) +function Base.:-(stt1::TensorTrain, stt2::TensorTrain, stts::Vararg{TensorTrain}) # Check that all tensor trains have the same length lengths = [length(stt.data) for stt in [stt1, stt2, stts...]] if !all(l -> l == lengths[1], lengths) @@ -427,7 +663,7 @@ This function computes the sum of tensor trains using ITensors.Algorithm("direct Note: Algorithm parameter is accepted for interface compatibility but Algorithm("directsum") is always used for optimal numerical precision. """ -function Base.:+(alg::Algorithm, stt1::TensorTrain, stts...) +function Base.:+(alg::Algorithm, stt1::TensorTrain, stts::Vararg{TensorTrain}) # Check that all tensor trains have the same length lengths = [length(stt.data) for stt in [stt1, stts...]] if !all(l -> l == lengths[1], lengths) @@ -447,6 +683,17 @@ function Base.:+(alg::Algorithm, stt1::TensorTrain, stts...) return TensorTrain(mps_sum, stt1.llim, stt1.rlim) end +""" +Add TensorTrain objects with algorithm keyword argument. + +This function accepts an `alg` keyword argument for interface compatibility, +but always uses Algorithm("directsum") for optimal numerical precision. +""" +function Base.:+(stt1::TensorTrain, stts::Vararg{TensorTrain}; alg::Union{String,Algorithm}="directsum", kwargs...) + alg_obj = alg isa String ? Algorithm(alg) : alg + return +(alg_obj, stt1, stts...) +end + """ truncate!(stt::TensorTrain; cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), kwargs...) @@ -521,7 +768,10 @@ function maxlinkdim(stt::TensorTrain) end function _extractsite(x::TensorTrain, n::Int)::Vector{Index} - if n == 1 + if length(x) == 1 + # Single site: all indices are site indices + return collect(inds(x[n])) + elseif n == 1 return copy(uniqueinds(x[n], x[n + 1])) elseif n == length(x) return copy(uniqueinds(x[n], x[n - 1])) @@ -544,4 +794,419 @@ the site (physical) indices for the corresponding tensor in the train. # Returns - `Vector{Vector{Index}}`: Vector of site index vectors, one per tensor """ -siteinds(x::TensorTrain) = [_extractsite(x, n) for n in eachindex(x)] \ No newline at end of file +siteinds(x::TensorTrain) = [_extractsite(x, n) for n in eachindex(x)] + +""" + prime(tt::TensorTrain, args...; kwargs...) + +Apply `ITensors.prime` to all ITensors in a TensorTrain. + +This function applies the `prime` function to each tensor in the tensor train, +returning a new TensorTrain with primed indices. + +# Arguments +- `tt::TensorTrain`: The tensor train to prime +- `args...`: Arguments passed to `ITensors.prime` +- `kwargs...`: Keyword arguments passed to `ITensors.prime` + +# Returns +- `TensorTrain`: A new TensorTrain with primed indices + +# Examples +```julia +tt_primed = prime(tt, 1) # Prime all indices by 1 +tt_primed = prime(tt, 1; inds=sites) # Prime only specific indices +``` +""" +function ITensors.prime(tt::TensorTrain, args...; kwargs...) 
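+    # Priming is applied tensor-by-tensor; the orthogonality bounds `llim`/`rlim`
+    # are carried over unchanged, since changing prime levels does not move the
+    # orthogonality center.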
+ return TensorTrain([ITensors.prime(t, args...; kwargs...) for t in tt.data], tt.llim, tt.rlim) +end + +""" + noprime(tt::TensorTrain, args...; kwargs...) + +Apply `ITensors.noprime` to all ITensors in a TensorTrain. + +This function applies the `noprime` function to each tensor in the tensor train, +returning a new TensorTrain with unprimed indices. + +# Arguments +- `tt::TensorTrain`: The tensor train to unprime +- `args...`: Arguments passed to `ITensors.noprime` +- `kwargs...`: Keyword arguments passed to `ITensors.noprime` + +# Returns +- `TensorTrain`: A new TensorTrain with unprimed indices +""" +function ITensors.noprime(tt::TensorTrain, args...; kwargs...) + return TensorTrain([ITensors.noprime(t, args...; kwargs...) for t in tt.data], tt.llim, tt.rlim) +end + +""" + replaceprime(tt::TensorTrain, p1 => p2; kwargs...) + +Apply `ITensors.replaceprime` to all ITensors in a TensorTrain. + +This function applies the `replaceprime` function to each tensor in the tensor train, +replacing prime level `p1` with `p2` for all matching indices. + +# Arguments +- `tt::TensorTrain`: The tensor train to modify +- `p1 => p2`: Pair specifying the prime level replacement (e.g., `1 => 0` or `2 => 1`) +- `kwargs...`: Keyword arguments passed to `ITensors.replaceprime` + +# Returns +- `TensorTrain`: A new TensorTrain with replaced prime levels + +# Examples +```julia +tt_replaced = replaceprime(tt, 1 => 0) # Replace prime level 1 with 0 +tt_replaced = replaceprime(tt, 2 => 1; inds=sites) # Replace only specific indices +``` +""" +function ITensors.replaceprime(tt::TensorTrain, p1_p2::Pair; kwargs...) + return TensorTrain([ITensors.replaceprime(t, p1_p2; kwargs...) for t in tt.data], tt.llim, tt.rlim) +end + +""" + linkinds(tt::TensorTrain) + +Extract the link (bond) indices from a TensorTrain. + +This function returns a vector of link indices connecting adjacent tensors +in the tensor train. For a TensorTrain of length N, it returns N-1 link indices. + +# Arguments +- `tt::TensorTrain`: The tensor train to extract link indices from + +# Returns +- `Vector{Index}`: Vector of link indices connecting adjacent tensors +""" +function linkinds(tt::TensorTrain) + N = length(tt) + if N <= 1 + return Index[] + end + links = Index[] + for n in 1:(N - 1) + # Link index is the common index between tensor n and n+1 + common = commoninds(tt[n], tt[n + 1]) + if length(common) != 1 + error("Expected exactly one common index between tensors $n and $(n+1), got $(length(common))") + end + push!(links, only(common)) + end + return links +end + +""" + linkind(tt::TensorTrain, p::Int) + +Get the link index at position p in a TensorTrain. + +Position p refers to the link between tensor p and p+1. +Valid positions are 1 to length(tt)-1. + +# Arguments +- `tt::TensorTrain`: The tensor train +- `p::Int`: Position of the link (1-indexed, between tensor p and p+1) + +# Returns +- `Index`: The link index at position p +""" +function linkind(tt::TensorTrain, p::Int) + links = linkinds(tt) + if p < 1 || p > length(links) + error("Link position $p out of range. Valid range: 1 to $(length(links))") + end + return links[p] +end + +""" + findsite(tt::TensorTrain, site::Index) + +Find the position of a site index in a TensorTrain. + +This function searches for the site index in the tensor train and returns +the position (1-indexed) where it is found. Returns `nothing` if not found. 
+ +# Arguments +- `tt::TensorTrain`: The tensor train to search +- `site::Index`: The site index to find + +# Returns +- `Union{Int, Nothing}`: Position of the site index, or `nothing` if not found +""" +function findsite(tt::TensorTrain, site::Index) + sites = siteinds(tt) + for (pos, site_vec) in enumerate(sites) + if site in site_vec + return pos + end + end + return nothing +end + +""" + findsites(tt::TensorTrain, site::Index) + +Find all positions of a site index in a TensorTrain. + +This function searches for the site index in the tensor train and returns +all positions (1-indexed) where it is found. + +# Arguments +- `tt::TensorTrain`: The tensor train to search +- `site::Index`: The site index to find + +# Returns +- `Vector{Int}`: Vector of positions where the site index is found +""" +function findsites(tt::TensorTrain, site::Index) + sites = siteinds(tt) + positions = Int[] + for (pos, site_vec) in enumerate(sites) + if site in site_vec + push!(positions, pos) + end + end + return positions +end + +""" + isortho(tt::TensorTrain) + +Check if a TensorTrain is orthogonal (canonical form). + +This function checks whether the tensor train is in orthogonal/canonical form +by delegating to ITensorMPS.isortho after converting to MPS. + +# Arguments +- `tt::TensorTrain`: The tensor train to check + +# Returns +- `Bool`: `true` if the tensor train is orthogonal, `false` otherwise +""" +function isortho(tt::TensorTrain) + mps = ITensorMPS.MPS(tt) + return ITensorMPS.isortho(mps) +end + +""" + orthocenter(tt::TensorTrain) + +Get the orthogonality center position of a TensorTrain. + +This function returns the position of the orthogonality center in the tensor train +by delegating to ITensorMPS.orthocenter after converting to MPS. + +# Arguments +- `tt::TensorTrain`: The tensor train + +# Returns +- `Int`: Position of the orthogonality center (1-indexed) +""" +function orthocenter(tt::TensorTrain) + mps = ITensorMPS.MPS(tt) + return ITensorMPS.orthocenter(mps) +end + +#=== +Random tensor train generation functions +====# + +import Random +import ITensorMPS + +""" + random_mps(sites::Vector{<:Index}; linkdims=1) + +Construct a random TensorTrain (MPS) with link dimension `linkdims` which by +default has element type `Float64`. + +`linkdims` can also accept a `Vector{Int}` with +`length(linkdims) == length(sites) - 1` for constructing an +MPS with non-uniform bond dimension. + +# Arguments +- `sites::Vector{<:Index}`: Vector of site indices +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) + +# Examples +```julia +sites = [Index(2, "Qubit,n=\$n") for n = 1:5] +psi = random_mps(sites; linkdims=3) +``` +""" +function random_mps(sites::Vector{<:Index}; linkdims::Union{Integer,Vector{<:Integer}}=1) + return random_mps(Random.default_rng(), sites; linkdims) +end + +""" + random_mps(rng::Random.AbstractRNG, sites::Vector{<:Index}; linkdims=1) + +Construct a random TensorTrain (MPS) with link dimension `linkdims` using the specified RNG. 
+ +# Arguments +- `rng::Random.AbstractRNG`: Random number generator +- `sites::Vector{<:Index}`: Vector of site indices +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(rng::Random.AbstractRNG, sites::Vector{<:Index}; linkdims::Union{Integer,Vector{<:Integer}}=1) + return random_mps(rng, Float64, sites; linkdims) +end + +""" + random_mps(eltype::Type{<:Number}, sites::Vector{<:Index}; linkdims=1) + +Construct a random TensorTrain (MPS) with specified element type and link dimension. + +# Arguments +- `eltype::Type{<:Number}`: Element type (e.g., Float64, ComplexF64) +- `sites::Vector{<:Index}`: Vector of site indices +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(eltype::Type{<:Number}, sites::Vector{<:Index}; linkdims::Union{Integer,Vector{<:Integer}}=1) + return random_mps(Random.default_rng(), eltype, sites; linkdims) +end + +""" + random_mps(rng::Random.AbstractRNG, eltype::Type{<:Number}, sites::Vector{<:Index}; linkdims=1) + +Construct a random TensorTrain (MPS) with specified RNG, element type, and link dimension. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator +- `eltype::Type{<:Number}`: Element type (e.g., Float64, ComplexF64) +- `sites::Vector{<:Index}`: Vector of site indices +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(rng::Random.AbstractRNG, eltype::Type{<:Number}, sites::Vector{<:Index}; linkdims::Union{Integer,Vector{<:Integer}}=1) + mps = ITensorMPS.random_mps(rng, eltype, sites; linkdims) + return TensorTrain(mps) +end + +""" + random_mps(sites::Vector{<:Index}, state; linkdims=1) + +Construct a random TensorTrain (MPS) with initial state (for quantum number conservation). + +# Arguments +- `sites::Vector{<:Index}`: Vector of site indices +- `state`: Initial state specification (function or vector) +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(sites::Vector{<:Index}, state; linkdims::Union{Integer,Vector{<:Integer}}=1) + return random_mps(Random.default_rng(), sites, state; linkdims) +end + +""" + random_mps(rng::Random.AbstractRNG, sites::Vector{<:Index}, state; linkdims=1) + +Construct a random TensorTrain (MPS) with initial state using the specified RNG. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator +- `sites::Vector{<:Index}`: Vector of site indices +- `state`: Initial state specification (function or vector) +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(rng::Random.AbstractRNG, sites::Vector{<:Index}, state; linkdims::Union{Integer,Vector{<:Integer}}=1) + return random_mps(rng, Float64, sites, state; linkdims) +end + +""" + random_mps(eltype::Type{<:Number}, sites::Vector{<:Index}, state; linkdims=1) + +Construct a random TensorTrain (MPS) with element type and initial state. 
+ +# Arguments +- `eltype::Type{<:Number}`: Element type (e.g., Float64, ComplexF64) +- `sites::Vector{<:Index}`: Vector of site indices +- `state`: Initial state specification (function or vector) +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(eltype::Type{<:Number}, sites::Vector{<:Index}, state; linkdims::Union{Integer,Vector{<:Integer}}=1) + return random_mps(Random.default_rng(), eltype, sites, state; linkdims) +end + +""" + random_mps(rng::Random.AbstractRNG, eltype::Type{<:Number}, sites::Vector{<:Index}, state; linkdims=1) + +Construct a random TensorTrain (MPS) with RNG, element type, and initial state. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator +- `eltype::Type{<:Number}`: Element type (e.g., Float64, ComplexF64) +- `sites::Vector{<:Index}`: Vector of site indices +- `state`: Initial state specification (function or vector) +- `linkdims::Union{Integer,Vector{<:Integer}}`: Link dimension(s) (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPS) +""" +function random_mps(rng::Random.AbstractRNG, eltype::Type{<:Number}, sites::Vector{<:Index}, state; linkdims::Union{Integer,Vector{<:Integer}}=1) + mps = ITensorMPS.random_mps(rng, eltype, sites, state; linkdims) + return TensorTrain(mps) +end + +""" + random_mpo(sites::Vector{<:Index}, m::Int=1) + +Construct a random TensorTrain (MPO) with specified sites. + +# Arguments +- `sites::Vector{<:Index}`: Vector of site indices +- `m::Int`: Currently only m=1 is supported (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPO) + +# Examples +```julia +sites = [Index(2, "Qubit,n=\$n") for n = 1:5] +mpo = random_mpo(sites) +``` +""" +function random_mpo(sites::Vector{<:Index}, m::Int=1) + return random_mpo(Random.default_rng(), sites, m) +end + +""" + random_mpo(rng::Random.AbstractRNG, sites::Vector{<:Index}, m::Int=1) + +Construct a random TensorTrain (MPO) with specified RNG and sites. 
+ +# Arguments +- `rng::Random.AbstractRNG`: Random number generator +- `sites::Vector{<:Index}`: Vector of site indices +- `m::Int`: Currently only m=1 is supported (default: 1) + +# Returns +- `TensorTrain`: A random tensor train (MPO) +""" +function random_mpo(rng::Random.AbstractRNG, sites::Vector{<:Index}, m::Int=1) + mpo = ITensorMPS.random_mpo(rng, sites, m) + return TensorTrain(mpo) +end \ No newline at end of file diff --git a/test/contraction_test.jl b/test/contraction_test.jl deleted file mode 100644 index 52951e5..0000000 --- a/test/contraction_test.jl +++ /dev/null @@ -1,113 +0,0 @@ -@testitem "contraction.jl" begin - include("util.jl") - - import T4AITensorCompat: TensorTrain, contract, dist - import ITensors: ITensors, ITensor, Index, random_itensor - import ITensorMPS - import ITensors: Algorithm, @Algorithm_str - import LinearAlgebra: norm - ITensors.disable_warn_order() - using Random - - # Test algorithms - algs = ["densitymatrix", "fit", "zipup"] - eps = Dict("densitymatrix" => 1e-6, "fit" => 1e-12, "zipup" => 1e-12) - linkdims = 3 - R = 5 - - for alg in algs - @testset "MPO-MPO contraction (x-y-z) with $alg" begin - Random.seed!(1234) - - sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] - sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] - sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] - - sitesa = collect(collect.(zip(sitesx, sitesy))) - sitesb = collect(collect.(zip(sitesy, sitesz))) - a_mpo = _random_mpo(sitesa; linkdims = linkdims) - b_mpo = _random_mpo(sitesb; linkdims = linkdims) - - # Convert to TensorTrain - a = TensorTrain(a_mpo) - b = TensorTrain(b_mpo) - - ab_ref = contract(a, b; alg = Algorithm"naive"()) - ab = contract(a, b; alg = Algorithm(alg)) - @test relative_error(ab_ref, ab) < eps[alg] - end - end - - for alg in algs - @testset "MPO-MPO contraction (xk-y-z) with $alg" begin - Random.seed!(1234) - sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] - sitesk = [Index(2, "Qubit,k=$n") for n = 1:R] - sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] - sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] - - sitesa = collect(collect.(zip(sitesx, sitesk, sitesy))) - sitesb = collect(collect.(zip(sitesy, sitesz))) - a_mpo = _random_mpo(sitesa; linkdims = linkdims) - b_mpo = _random_mpo(sitesb; linkdims = linkdims) - - # Convert to TensorTrain - a = TensorTrain(a_mpo) - b = TensorTrain(b_mpo) - - ab_ref = contract(a, b; alg = Algorithm"naive"()) - ab = contract(a, b; alg = Algorithm(alg)) - @test relative_error(ab_ref, ab) < eps[alg] - end - end - - for alg in algs - @testset "MPO-MPO contraction (xk-y-zl) with $alg" begin - Random.seed!(1234) - sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] - sitesk = [Index(2, "Qubit,k=$n") for n = 1:R] - sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] - sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] - sitesl = [Index(2, "Qubit,l=$n") for n = 1:R] - - sitesa = collect(collect.(zip(sitesx, sitesk, sitesy))) - sitesb = collect(collect.(zip(sitesy, sitesz, sitesl))) - a_mpo = _random_mpo(sitesa; linkdims = linkdims) - b_mpo = _random_mpo(sitesb; linkdims = linkdims) - - # Convert to TensorTrain - a = TensorTrain(a_mpo) - b = TensorTrain(b_mpo) - - ab_ref = contract(a, b; alg = Algorithm"naive"()) - ab = contract(a, b; alg = Algorithm(alg)) - - @test relative_error(ab_ref, ab) < eps[alg] - end - end - - for alg in algs - @testset "MPO-MPO contraction (xk-ym-zl) with $alg" begin - Random.seed!(1234) - sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] - sitesk = [Index(2, "Qubit,k=$n") for n = 1:R] - sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] - sitesz = 
[Index(2, "Qubit,z=$n") for n = 1:R] - sitesl = [Index(2, "Qubit,l=$n") for n = 1:R] - sitesm = [Index(2, "Qubit,m=$n") for n = 1:R] - - sitesa = collect(collect.(zip(sitesx, sitesk, sitesm, sitesy))) - sitesb = collect(collect.(zip(sitesy, sitesm, sitesz, sitesl))) - a_mpo = _random_mpo(sitesa; linkdims = linkdims) - b_mpo = _random_mpo(sitesb; linkdims = linkdims) - - # Convert to TensorTrain - a = TensorTrain(a_mpo) - b = TensorTrain(b_mpo) - - ab_ref = contract(a, b; alg = Algorithm"naive"()) - ab = contract(a, b; alg = Algorithm(alg)) - @test relative_error(ab_ref, ab) < eps[alg] - end - end -end \ No newline at end of file diff --git a/test/contraction_tests.jl b/test/contraction_tests.jl new file mode 100644 index 0000000..31b7653 --- /dev/null +++ b/test/contraction_tests.jl @@ -0,0 +1,270 @@ +@testitem "contraction.jl" begin + include("util.jl") + + import T4AITensorCompat: TensorTrain, contract, dist, fit + import ITensors: ITensors, ITensor, Index, random_itensor + import ITensorMPS + import ITensors: Algorithm, @Algorithm_str + import LinearAlgebra: norm + ITensors.disable_warn_order() + using Random + + # Test algorithms + algs = ["densitymatrix", "fit", "zipup"] + eps = Dict("densitymatrix" => 1e-6, "fit" => 1e-12, "zipup" => 1e-12) + linkdims = 3 + R = 5 + + for alg in algs + @testset "MPO-MPO contraction (x-y-z) with $alg" begin + Random.seed!(1234) + + sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] + sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] + sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] + + sitesa = collect(collect.(zip(sitesx, sitesy))) + sitesb = collect(collect.(zip(sitesy, sitesz))) + a_mpo = _random_mpo(sitesa; linkdims = linkdims) + b_mpo = _random_mpo(sitesb; linkdims = linkdims) + + # Convert to TensorTrain + a = TensorTrain(a_mpo) + b = TensorTrain(b_mpo) + + ab_ref = contract(a, b; alg = Algorithm"naive"()) + ab = contract(a, b; alg = Algorithm(alg)) + @test relative_error(ab_ref, ab) < eps[alg] + end + end + + for alg in algs + @testset "MPO-MPO contraction (xk-y-z) with $alg" begin + Random.seed!(1234) + sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] + sitesk = [Index(2, "Qubit,k=$n") for n = 1:R] + sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] + sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] + + sitesa = collect(collect.(zip(sitesx, sitesk, sitesy))) + sitesb = collect(collect.(zip(sitesy, sitesz))) + a_mpo = _random_mpo(sitesa; linkdims = linkdims) + b_mpo = _random_mpo(sitesb; linkdims = linkdims) + + # Convert to TensorTrain + a = TensorTrain(a_mpo) + b = TensorTrain(b_mpo) + + ab_ref = contract(a, b; alg = Algorithm"naive"()) + ab = contract(a, b; alg = Algorithm(alg)) + @test relative_error(ab_ref, ab) < eps[alg] + end + end + + for alg in algs + @testset "MPO-MPO contraction (xk-y-zl) with $alg" begin + Random.seed!(1234) + sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] + sitesk = [Index(2, "Qubit,k=$n") for n = 1:R] + sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] + sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] + sitesl = [Index(2, "Qubit,l=$n") for n = 1:R] + + sitesa = collect(collect.(zip(sitesx, sitesk, sitesy))) + sitesb = collect(collect.(zip(sitesy, sitesz, sitesl))) + a_mpo = _random_mpo(sitesa; linkdims = linkdims) + b_mpo = _random_mpo(sitesb; linkdims = linkdims) + + # Convert to TensorTrain + a = TensorTrain(a_mpo) + b = TensorTrain(b_mpo) + + ab_ref = contract(a, b; alg = Algorithm"naive"()) + ab = contract(a, b; alg = Algorithm(alg)) + + @test relative_error(ab_ref, ab) < eps[alg] + end + end + + for alg in algs + @testset "MPO-MPO 
contraction (xk-ym-zl) with $alg" begin + Random.seed!(1234) + sitesx = [Index(2, "Qubit,x=$n") for n = 1:R] + sitesk = [Index(2, "Qubit,k=$n") for n = 1:R] + sitesy = [Index(2, "Qubit,y=$n") for n = 1:R] + sitesz = [Index(2, "Qubit,z=$n") for n = 1:R] + sitesl = [Index(2, "Qubit,l=$n") for n = 1:R] + sitesm = [Index(2, "Qubit,m=$n") for n = 1:R] + + sitesa = collect(collect.(zip(sitesx, sitesk, sitesm, sitesy))) + sitesb = collect(collect.(zip(sitesy, sitesm, sitesz, sitesl))) + a_mpo = _random_mpo(sitesa; linkdims = linkdims) + b_mpo = _random_mpo(sitesb; linkdims = linkdims) + + # Convert to TensorTrain + a = TensorTrain(a_mpo) + b = TensorTrain(b_mpo) + + ab_ref = contract(a, b; alg = Algorithm"naive"()) + ab = contract(a, b; alg = Algorithm(alg)) + @test relative_error(ab_ref, ab) < eps[alg] + end + end + + @testset "fit function for summing multiple TensorTrain objects" begin + Random.seed!(1234) + R = 5 + linkdims = 3 + + # Create test sites + sites = [Index(2, "Qubit,s=$n") for n = 1:R] + + # Create multiple random MPS + mps1 = _random_mps(sites; linkdims = linkdims) + mps2 = _random_mps(sites; linkdims = linkdims) + mps3 = _random_mps(sites; linkdims = linkdims) + + # Convert to TensorTrain + tt1 = TensorTrain(mps1) + tt2 = TensorTrain(mps2) + tt3 = TensorTrain(mps3) + + # Create initial guess (use first tensor train) + init_tt = TensorTrain(mps1) + + # Test fit with equal coefficients + coeffs = [1.0, 1.0, 1.0] + result = fit([tt1, tt2, tt3], init_tt; coeffs=coeffs, nsweeps=2, cutoff=1e-12, maxdim=100) + + # Verify result is a TensorTrain + @test result isa TensorTrain + @test length(result) == R + + # Test fit with different coefficients + coeffs2 = [2.0, 0.5, -1.0] + result2 = fit([tt1, tt2, tt3], init_tt; coeffs=coeffs2, nsweeps=2, cutoff=1e-12, maxdim=100) + + @test result2 isa TensorTrain + @test length(result2) == R + + # Test fit with default coefficients (all ones) + result3 = fit([tt1, tt2], init_tt; nsweeps=2, cutoff=1e-12, maxdim=100) + @test result3 isa TensorTrain + @test length(result3) == R + end + + @testset "fit function accuracy" begin + Random.seed!(5678) + R = 4 + linkdims = 2 + + # Create test sites + sites = [Index(2, "Qubit,s=$n") for n = 1:R] + + # Create two random MPS + mps1 = _random_mps(sites; linkdims = linkdims) + mps2 = _random_mps(sites; linkdims = linkdims) + + # Convert to TensorTrain + tt1 = TensorTrain(mps1) + tt2 = TensorTrain(mps2) + + # Compute exact sum using direct sum algorithm + exact_sum = tt1 + tt2 + + # Create initial guess (use exact sum truncated) + init_tt = exact_sum + + # Fit the sum + coeffs = [1.0, 1.0] + fitted = fit([tt1, tt2], init_tt; coeffs=coeffs, nsweeps=3, cutoff=1e-15, maxdim=100) + + # Check that fitted result is close to exact sum + # Note: fit is an approximation, so we check for reasonable accuracy + error = dist(fitted, exact_sum) + @test error < 1e-5 # I observed this function is sometime less accurate than direct sum of the input states. 
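+        # The tolerance above is intentionally loose: the variational fit is an
+        # approximation and, as noted in the docstring of `fit`, it can be less
+        # accurate than the direct-sum construction used for `exact_sum`.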
+ end +end + +@testitem "contraction_tests.jl/product_MPO_times_MPS" begin + using Test + using ITensors + import ITensors: Algorithm, @Algorithm_str + import ITensorMPS + using T4AITensorCompat: random_mps, random_mpo, product, apply, siteinds, TensorTrain + + ITensors.disable_warn_order() + + @testset "product(A::MPO, ψ::MPS) matches ITensorMPS.contract + replaceprime" for R in (2, 3) + # Build simple qubit sites + sites = [Index(2, "Qubit, s=$(n)") for n in 1:R] + ψ = random_mps(sites) + A = random_mpo(sites) + + # Under test + f = product(A, ψ; alg="naive", cutoff=1e-25) + + # Reference using ITensorMPS.contract + replaceprime + A_mpo = ITensorMPS.MPO(A) + ψ_mps = ITensorMPS.MPS(ψ) + f_ref_mps = ITensorMPS.contract(A_mpo, ψ_mps; alg=Algorithm("naive"), cutoff=1e-25) + f_ref_mps = ITensorMPS.replaceprime(f_ref_mps, 1 => 0) + f_ref = TensorTrain(f_ref_mps) + + # Compare full vectors + s_order = reverse(sites) # standard order used in other tests + f_vec = vec(Array(reduce(*, f), s_order)) + f_ref_vec = vec(Array(reduce(*, f_ref), s_order)) + @test f_vec ≈ f_ref_vec atol=1e-12 rtol=1e-12 + end + + # Alias check + @testset "apply alias equals product (MPO*MPS)" for R in (2, 3) + sites = [Index(2, "Qubit, s=$(n)") for n in 1:R] + ψ = random_mps(sites) + A = random_mpo(sites) + @test Array(reduce(*, apply(A, ψ; alg="naive")), reverse(sites)) ≈ Array(reduce(*, product(A, ψ; alg="naive")), reverse(sites)) + end +end + +@testitem "contraction_tests.jl/product_MPO_times_MPO" begin + using Test + using ITensors + import ITensors: Algorithm, @Algorithm_str + import ITensorMPS + using T4AITensorCompat: random_mpo, product, apply, siteinds, TensorTrain + + ITensors.disable_warn_order() + + @testset "product(A::MPO, B::MPO) matches ITensorMPS.contract(A', B) + replaceprime" for R in (2, 3) + # Build simple qubit sites + sites = [Index(2, "Qubit, s=$(n)") for n in 1:R] + A = random_mpo(sites) + B = random_mpo(sites) + + # Under test (zipup is standard for MPO*MPO) + C = product(A, B; alg="zipup", cutoff=1e-25) + + # Reference using ITensorMPS.contract(A', B) + replaceprime(2=>1) + A_mpo = ITensorMPS.MPO(A) + B_mpo = ITensorMPS.MPO(B) + C_ref_mpo = ITensorMPS.contract(A_mpo', B_mpo; alg=Algorithm("zipup")) + C_ref_mpo = ITensorMPS.replaceprime(C_ref_mpo, 2 => 1) + C_ref = TensorTrain(C_ref_mpo) + + # Compare dense tensors of the MPOs in a consistent index order + flatten_sites(x) = collect(Iterators.flatten(siteinds(x))) + C_arr = Array(reduce(*, C), flatten_sites(C)) + C_ref_arr = Array(reduce(*, C_ref), flatten_sites(C_ref)) + @test C_arr ≈ C_ref_arr atol=1e-12 rtol=1e-12 + end + + # Alias check + @testset "apply alias equals product (MPO*MPO)" for R in (2, 3) + sites = [Index(2, "Qubit, s=$(n)") for n in 1:R] + A = random_mpo(sites) + B = random_mpo(sites) + flatten_sites(x) = collect(Iterators.flatten(siteinds(x))) + @test Array(reduce(*, apply(A, B; alg="zipup")), flatten_sites(apply(A, B; alg="zipup"))) ≈ Array(reduce(*, product(A, B; alg="zipup")), flatten_sites(product(A, B; alg="zipup"))) + end +end \ No newline at end of file diff --git a/test/tensortrain_test.jl b/test/tensortrain_tests.jl similarity index 86% rename from test/tensortrain_test.jl rename to test/tensortrain_tests.jl index 353054a..045ff84 100644 --- a/test/tensortrain_test.jl +++ b/test/tensortrain_tests.jl @@ -1,8 +1,8 @@ @testitem "tensortrain.jl" begin include("util.jl") - import T4AITensorCompat: TensorTrain, dist - import ITensors: ITensor, Index, random_itensor + import T4AITensorCompat: TensorTrain, dist, siteinds, 
random_mps, random_mpo + import ITensors: ITensor, Index, random_itensor, dim import ITensorMPS import ITensors: Algorithm, @Algorithm_str import LinearAlgebra: norm @@ -683,4 +683,94 @@ @test relative_error(a_plus_a, a_plus_a_truncated_copy) < 1e-13 end + + @testset "TensorTrain siteinds function - random_mps" begin + # Test siteinds with random_mps + sites = [Index(2, "Site,n=$n") for n in 1:5] + mps = random_mps(sites; linkdims=3) + + # Test siteinds function + sites_extracted = siteinds(mps) + + # Check that we get the right number of sites + @test length(sites_extracted) == 5 + + # Check that each site contains the correct index + for (i, site) in enumerate(sites_extracted) + @test length(site) == 1 + @test site[1] == sites[i] + end + end + + @testset "TensorTrain siteinds function - random_mpo" begin + # Test siteinds with random_mpo + sites = [Index(2, "Site,n=$n") for n in 1:4] + mpo = random_mpo(sites) + + # Test siteinds function + sites_extracted = siteinds(mpo) + + # Check that we get the right number of sites + @test length(sites_extracted) == 4 + + # Check that each site contains two indices (upper and lower for MPO) + for site in sites_extracted + @test length(site) == 2 + # Check that indices have the same dimension + @test dim(site[1]) == dim(site[2]) + end + end + + @testset "TensorTrain siteinds function - MPO with multiple site indices" begin + # Test siteinds with MPO that has multiple site indices per site + # This simulates the case used in T4AQuantics + sites1 = [Index(2, "x=$n") for n in 1:3] + sites2 = [Index(2, "y=$n") for n in 1:3] + sites = [[x, y] for (x, y) in zip(sites1, sites2)] + + # Create MPO manually + links = [Index(2, "Link,l=$n") for n in 1:2] + t1 = random_itensor(sites[1]..., links[1]) + t2 = random_itensor(links[1], sites[2]..., links[2]) + t3 = random_itensor(links[2], sites[3]...) + mpo = TensorTrain([t1, t2, t3]) + + # Test siteinds function + sites_extracted = siteinds(mpo) + + # Check that we get the right number of sites + @test length(sites_extracted) == 3 + + # Check that each site contains the correct indices + for (i, site) in enumerate(sites_extracted) + @test length(site) == 2 + @test Set(site) == Set(sites[i]) + end + end + + @testset "TensorTrain siteinds function - empty TensorTrain" begin + # Test siteinds with empty TensorTrain + empty_tt = TensorTrain(Vector{ITensor}()) + + # Test siteinds function + sites = siteinds(empty_tt) + + # Check that we get an empty vector + @test length(sites) == 0 + end + + @testset "TensorTrain siteinds function - single site TensorTrain" begin + # Test siteinds with single site TensorTrain + i1 = Index(2, "i1") + t1 = random_itensor(i1) + stt = TensorTrain([t1]) + + # Test siteinds function + sites = siteinds(stt) + + # Check that we get one site + @test length(sites) == 1 + @test length(sites[1]) == 1 + @test sites[1][1] == i1 + end end
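As a closing illustration of how the pieces added in this patch fit together, here is a minimal usage sketch. It relies only on names introduced or exported by this diff (`random_mps`, `random_mpo`, `product`, `apply`, `fit`, `truncate!`, and the direct-sum `+` on `TensorTrain`); the site labels, bond dimensions, and tolerances are illustrative choices mirroring the tests above, not prescribed values.

```julia
using ITensors: Index
using T4AITensorCompat: TensorTrain, random_mps, random_mpo, product, apply, fit, truncate!

# Qubit-like site indices, following the convention used in the tests above
sites = [Index(2, "Qubit,s=$n") for n in 1:5]

# Random MPS and MPO wrapped as TensorTrain objects
ψ = random_mps(sites; linkdims=3)
A = random_mpo(sites)
B = random_mpo(sites)

# MPO * MPS: `product` contracts and unprimes the result; `apply` is the alias
Aψ = product(A, ψ; alg="naive", cutoff=1e-25)
Aψ_alias = apply(A, ψ; alg="naive", cutoff=1e-25)

# MPO * MPO: `product` primes A internally and maps prime level 2 back to 1
AB = product(A, B; alg="zipup", cutoff=1e-25)

# High-precision sum via the direct-sum `+`, optionally truncated afterwards
ψ2 = random_mps(sites; linkdims=3)
s = ψ + ψ2
truncate!(s; cutoff=1e-12, maxdim=50)

# Variational sum of several states with coefficients; as noted in the `fit`
# docstring, this can be less accurate than the direct sum
approx = fit([ψ, ψ2], s; coeffs=[1.0, 2.0], nsweeps=2, cutoff=1e-12, maxdim=100)
```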