diff --git a/src/tensortrain.jl b/src/tensortrain.jl index 9638c74..26cc036 100644 --- a/src/tensortrain.jl +++ b/src/tensortrain.jl @@ -255,49 +255,102 @@ function TensorTrain(A::AbstractArray, sites; kwargs...) end """ - TensorTrain(tt_data; sites, kwargs...) + TensorTrain(tt_data::AbstractVector{<:AbstractArray}; sites) -Construct a TensorTrain from tensor train data (like QuanticsTCI/TensorCrossInterpolation) with specified sites. +Construct a TensorTrain from a vector of core arrays and site indices. -This constructor is used for converting from other tensor train formats (e.g., QuanticsTCI, TensorCrossInterpolation) -to TensorTrain by specifying the site indices. +This constructor supports sites with arbitrary numbers of physical indices per site. +It validates the number of physical indices at each site, constructs link indices by +reading the core bond dimensions, and builds a `Vector{ITensor}` directly. # Arguments -- `tt_data`: Tensor train data (e.g., from QuanticsTCI or TensorCrossInterpolation) -- `sites`: Site indices - either `Vector{Index}` (for MPS) or `Vector{Vector{Index}}` (for MPO) +- `tt_data::AbstractVector{<:AbstractArray}`: Vector of TT core arrays. +- `sites`: Either `AbstractVector{<:Index}` (MPS case, one physical index per site) or + `AbstractVector{<:AbstractVector{<:Index}}` (general case, arbitrary number of physical indices per site). + Each element `sites[n]` is either a single `Index` (MPS) or a vector of physical indices for site `n`. -# Keyword Arguments -- `kwargs...`: Additional keyword arguments passed to ITensorMPS constructors +# Assumptions +- Core axis order is `(left_link, physical..., right_link)`. +- Each core has `length(sites[n]) + 2` dimensions (physical indices + 2 links). +- Boundary link dimensions are represented in the arrays (typically 1 on edges). # Returns - `TensorTrain`: A new TensorTrain object """ -function TensorTrain(tt_data; sites, kwargs...) +function TensorTrain(tt_data::AbstractVector{<:AbstractArray}; sites) + N = length(tt_data) + N == 0 && return TensorTrain(Vector{ITensor}()) + + # Normalize sites to AbstractVector{AbstractVector{Index}} format + # Check if sites is AbstractVector{<:Index} (MPS case) if length(sites) > 0 && sites[1] isa Index - # MPS case: sites is Vector{Index} - mps = ITensorMPS.MPS(tt_data, sites; kwargs...) - return TensorTrain(mps) - else - # MPO case: sites is Vector{Vector{Index}} - # Check if tt_data is TensorCrossInterpolation.TensorTrain - # Try to use TensorCrossInterpolation conversion if available - tt_type = typeof(tt_data) - if hasmethod(ITensorMPS.MPO, (typeof(tt_data),); kwargs...) - # Try ITensorMPS.MPO with sites keyword argument (for TensorCrossInterpolation) - try - mpo = ITensorMPS.MPO(tt_data; sites=sites, kwargs...) - return TensorTrain(mpo) - catch - # Fallback to generic MPO constructor - mpo = ITensorMPS.MPO(tt_data, sites; kwargs...) - return TensorTrain(mpo) - end + # Convert to Vector{Vector{Index}} format + sites = [[s] for s in sites] + elseif length(sites) > 0 && !(sites[1] isa AbstractVector) + error("sites must be either AbstractVector{<:Index} or AbstractVector{<:AbstractVector{<:Index}}") + end + + length(sites) == N || error("Length mismatch: length(sites)=$(length(sites)) must equal length(tt_data)=$N") + + # Determine expected number of physical indices per site + expected_physical_per_site = [length(sites[n]) for n in 1:N] + + # Validate core dimensionalities and extract link dimensions + left_link_dims = Vector{Int}(undef, N) + right_link_dims = Vector{Int}(undef, N) + for n in 1:N + core = tt_data[n] + nd = ndims(core) + expected_nd = expected_physical_per_site[n] + 2 + nd == expected_nd || error("Core $n has ndims=$nd but expected $expected_nd (physical=$(expected_physical_per_site[n]) + 2 links)") + left_link_dims[n] = size(core, 1) + right_link_dims[n] = size(core, nd) + if n > 1 && left_link_dims[n] != right_link_dims[n - 1] + error("Bond mismatch between cores $(n-1) and $n: right=$(right_link_dims[n-1]) vs left=$(left_link_dims[n])") + end + end + + # Create link indices from link dimensions for the N-1 bonds + links = N > 1 ? [Index(right_link_dims[n], "Link,l=$n") for n in 1:(N - 1)] : Index[] + + # Build ITensors per site + tensors = Vector{ITensor}(undef, N) + for n in 1:N + core = tt_data[n] + nphys = expected_physical_per_site[n] + site_inds_n = sites[n] + + # Use site indices as given (no special MPO convention) + site_inds_tuple = Tuple(site_inds_n) + + # Validate physical dimensions match the array sizes on axes 2:(1+nphys) + for p in 1:nphys + idx = site_inds_tuple[p] + array_dim = size(core, 1 + p) + ITensors.dim(idx) == array_dim || error( + "Physical dim mismatch at site $n, physical $p: array=$(array_dim) vs index=$(ITensors.dim(idx))", + ) + end + + # Assemble index tuple in the assumed order: (left?, physical..., right?) + inds_tuple = if n == 1 && n == N + # Single-site + site_inds_tuple + elseif n == 1 + # First: physical..., right_link + Tuple(vcat(collect(site_inds_tuple), [links[n]])) + elseif n == N + # Last: left_link, physical... + Tuple(vcat([links[n - 1]], collect(site_inds_tuple))) else - # Try generic MPO constructor - mpo = ITensorMPS.MPO(tt_data, sites; kwargs...) - return TensorTrain(mpo) + # Middle: left_link, physical..., right_link + Tuple(vcat([links[n - 1]], collect(site_inds_tuple), [links[n]])) end + + tensors[n] = ITensor(core, inds_tuple...) end + + return TensorTrain(tensors) end # Iterator implementation diff --git a/test/tensortrain_tests.jl b/test/tensortrain_tests.jl index 045ff84..233621a 100644 --- a/test/tensortrain_tests.jl +++ b/test/tensortrain_tests.jl @@ -2,7 +2,7 @@ include("util.jl") import T4AITensorCompat: TensorTrain, dist, siteinds, random_mps, random_mpo - import ITensors: ITensor, Index, random_itensor, dim + import ITensors: ITensor, Index, random_itensor, dim, inds import ITensorMPS import ITensors: Algorithm, @Algorithm_str import LinearAlgebra: norm @@ -773,4 +773,206 @@ @test length(sites[1]) == 1 @test sites[1][1] == i1 end + + @testset "TensorTrain from AbstractVector{<:AbstractArray} - MPS case" begin + # Create site indices + sites = [Index(2, "Site,n=$n") for n in 1:3] + + # Create TT cores as arrays: (left_link, physical, right_link) + # First core: (1, 2, 3) - left boundary is 1 + # Middle core: (3, 2, 4) - bond dimension 3 + # Last core: (4, 2, 1) - right boundary is 1 + core1 = randn(Float64, 1, 2, 3) + core2 = randn(Float64, 3, 2, 4) + core3 = randn(Float64, 4, 2, 1) + tt_data = [core1, core2, core3] + + # Construct TensorTrain using AbstractVector{<:Index} format + tt = TensorTrain(tt_data; sites=sites) + + # Verify structure + @test length(tt) == 3 + @test length(tt.data) == 3 + + # Verify that each tensor has correct indices + # First tensor should have: (site[1], link[1]) + @test length(inds(tt[1])) == 2 + @test sites[1] in inds(tt[1]) + + # Middle tensor should have: (link[1], site[2], link[2]) + @test length(inds(tt[2])) == 3 + @test sites[2] in inds(tt[2]) + + # Last tensor should have: (link[2], site[3]) + @test length(inds(tt[3])) == 2 + @test sites[3] in inds(tt[3]) + + # Verify physical dimensions match + for n in 1:3 + site_inds = siteinds(tt)[n] + @test length(site_inds) == 1 + @test site_inds[1] == sites[n] + @test dim(site_inds[1]) == 2 + end + + # Verify link dimensions + links = T4AITensorCompat.linkinds(tt) + @test length(links) == 2 + @test dim(links[1]) == 3 + @test dim(links[2]) == 4 + + # Verify that we can convert back to MPS and the data matches + mps = ITensorMPS.MPS(tt) + @test length(mps) == 3 + + # Verify round-trip: convert back to arrays and check dimensions + tt_reconstructed = TensorTrain(mps) + for n in 1:3 + @test length(inds(tt_reconstructed[n])) == length(inds(tt[n])) + end + end + + @testset "TensorTrain from AbstractVector{<:AbstractArray} - MPO case" begin + # Create site indices (MPO: each site has 2 physical indices) + sites = [[Index(2, "Site,n=$n,lower"), Index(2, "Site,n=$n,upper")] for n in 1:3] + + # Create TT cores as arrays: (left_link, physical1, physical2, right_link) + # First core: (1, 2, 2, 3) + # Middle core: (3, 2, 2, 4) + # Last core: (4, 2, 2, 1) + core1 = randn(Float64, 1, 2, 2, 3) + core2 = randn(Float64, 3, 2, 2, 4) + core3 = randn(Float64, 4, 2, 2, 1) + tt_data = [core1, core2, core3] + + # Construct TensorTrain using AbstractVector{AbstractVector{<:Index}} format + tt = TensorTrain(tt_data; sites=sites) + + # Verify structure + @test length(tt) == 3 + @test length(tt.data) == 3 + + # Verify that each tensor has correct indices + # First tensor should have: (site[1][1], site[1][2], link[1]) + @test length(inds(tt[1])) == 3 + + # Middle tensor should have: (link[1], site[2][1], site[2][2], link[2]) + @test length(inds(tt[2])) == 4 + + # Last tensor should have: (link[2], site[3][1], site[3][2]) + @test length(inds(tt[3])) == 3 + + # Verify physical dimensions match + for n in 1:3 + site_inds = siteinds(tt)[n] + @test length(site_inds) == 2 + @test Set(site_inds) == Set(sites[n]) + @test dim(site_inds[1]) == 2 + @test dim(site_inds[2]) == 2 + end + + # Verify link dimensions + links = T4AITensorCompat.linkinds(tt) + @test length(links) == 2 + @test dim(links[1]) == 3 + @test dim(links[2]) == 4 + + # Verify that we can convert back to MPO and the data matches + mpo = ITensorMPS.MPO(tt) + @test length(mpo) == 3 + end + + @testset "TensorTrain from AbstractVector{<:AbstractArray} - arbitrary physical indices" begin + # Test case with different numbers of physical indices per site + # Site 1: 1 physical index + # Site 2: 2 physical indices + # Site 3: 3 physical indices + sites = [ + [Index(2, "Site1")], + [Index(3, "Site2,lower"), Index(3, "Site2,upper")], + [Index(2, "Site3,a"), Index(2, "Site3,b"), Index(2, "Site3,c")] + ] + + # Create TT cores with appropriate dimensions + # Core 1: (1, 2, 3) - 1 physical index + # Core 2: (3, 3, 3, 4) - 2 physical indices + # Core 3: (4, 2, 2, 2, 1) - 3 physical indices + core1 = randn(Float64, 1, 2, 3) + core2 = randn(Float64, 3, 3, 3, 4) + core3 = randn(Float64, 4, 2, 2, 2, 1) + tt_data = [core1, core2, core3] + + # Construct TensorTrain + tt = TensorTrain(tt_data; sites=sites) + + # Verify structure + @test length(tt) == 3 + + # Verify physical dimensions match + for n in 1:3 + site_inds = siteinds(tt)[n] + @test length(site_inds) == length(sites[n]) + @test Set(site_inds) == Set(sites[n]) + end + + # Verify link dimensions + links = T4AITensorCompat.linkinds(tt) + @test length(links) == 2 + @test dim(links[1]) == 3 + @test dim(links[2]) == 4 + end + + @testset "TensorTrain from AbstractVector{<:AbstractArray} - error cases" begin + sites = [Index(2, "Site,n=$n") for n in 1:3] + + # Test length mismatch + core1 = randn(Float64, 1, 2, 3) + core2 = randn(Float64, 3, 2, 4) + wrong_sites = [Index(2, "Site,n=$n") for n in 1:2] # Wrong length + @test_throws Exception TensorTrain([core1, core2]; sites=wrong_sites) + + # Test dimension mismatch - wrong number of dimensions + wrong_core1 = randn(Float64, 1, 2, 3, 4) # 4 dims instead of 3 + @test_throws Exception TensorTrain([wrong_core1, core2]; sites=sites) + + # Test physical dimension mismatch + wrong_sites_dim = [Index(3, "Site,n=$n") for n in 1:3] # dim=3 instead of 2 + @test_throws Exception TensorTrain([core1, core2, randn(Float64, 4, 3, 1)]; sites=wrong_sites_dim) + + # Test bond dimension mismatch + wrong_core2 = randn(Float64, 5, 2, 4) # Left link dim=5, but core1 right link dim=3 + @test_throws Exception TensorTrain([core1, wrong_core2]; sites=sites[1:2]) + end + + @testset "TensorTrain from AbstractVector{<:AbstractArray} - single site" begin + # Single site MPS + sites = [Index(2, "Site1")] + core1 = randn(Float64, 1, 2, 1) + tt = TensorTrain([core1]; sites=sites) + + @test length(tt) == 1 + @test length(inds(tt[1])) == 1 + @test sites[1] in inds(tt[1]) + + # Single site MPO + sites_mpo = [[Index(2, "Site1,lower"), Index(2, "Site1,upper")]] + core1_mpo = randn(Float64, 1, 2, 2, 1) + tt_mpo = TensorTrain([core1_mpo]; sites=sites_mpo) + + @test length(tt_mpo) == 1 + @test length(inds(tt_mpo[1])) == 2 + site_inds = siteinds(tt_mpo)[1] + @test length(site_inds) == 2 + @test Set(site_inds) == Set(sites_mpo[1]) + end + + @testset "TensorTrain from AbstractVector{<:AbstractArray} - empty" begin + # Empty TensorTrain + tt = TensorTrain(Vector{Array{Float64,3}}(); sites=Index[]) + @test length(tt) == 0 + + # Empty TensorTrain with empty sites vector + tt2 = TensorTrain(Vector{Array{Float64,3}}(); sites=Vector{Index}[]) + @test length(tt2) == 0 + end end