Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 83 additions & 30 deletions src/tensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,49 +255,102 @@ function TensorTrain(A::AbstractArray, sites; kwargs...)
end

"""
TensorTrain(tt_data; sites, kwargs...)
TensorTrain(tt_data::AbstractVector{<:AbstractArray}; sites)

Construct a TensorTrain from tensor train data (like QuanticsTCI/TensorCrossInterpolation) with specified sites.
Construct a TensorTrain from a vector of core arrays and site indices.

This constructor is used for converting from other tensor train formats (e.g., QuanticsTCI, TensorCrossInterpolation)
to TensorTrain by specifying the site indices.
This constructor supports sites with arbitrary numbers of physical indices per site.
It validates the number of physical indices at each site, constructs link indices by
reading the core bond dimensions, and builds a `Vector{ITensor}` directly.

# Arguments
- `tt_data`: Tensor train data (e.g., from QuanticsTCI or TensorCrossInterpolation)
- `sites`: Site indices - either `Vector{Index}` (for MPS) or `Vector{Vector{Index}}` (for MPO)
- `tt_data::AbstractVector{<:AbstractArray}`: Vector of TT core arrays.
- `sites`: Either `AbstractVector{<:Index}` (MPS case, one physical index per site) or
`AbstractVector{<:AbstractVector{<:Index}}` (general case, arbitrary number of physical indices per site).
Each element `sites[n]` is either a single `Index` (MPS) or a vector of physical indices for site `n`.

# Keyword Arguments
- `kwargs...`: Additional keyword arguments passed to ITensorMPS constructors
# Assumptions
- Core axis order is `(left_link, physical..., right_link)`.
- Each core has `length(sites[n]) + 2` dimensions (physical indices + 2 links).
- Boundary link dimensions are represented in the arrays (typically 1 on edges).

# Returns
- `TensorTrain`: A new TensorTrain object
"""
function TensorTrain(tt_data; sites, kwargs...)
function TensorTrain(tt_data::AbstractVector{<:AbstractArray}; sites)
N = length(tt_data)
N == 0 && return TensorTrain(Vector{ITensor}())

# Normalize sites to AbstractVector{AbstractVector{Index}} format
# Check if sites is AbstractVector{<:Index} (MPS case)
if length(sites) > 0 && sites[1] isa Index
# MPS case: sites is Vector{Index}
mps = ITensorMPS.MPS(tt_data, sites; kwargs...)
return TensorTrain(mps)
else
# MPO case: sites is Vector{Vector{Index}}
# Check if tt_data is TensorCrossInterpolation.TensorTrain
# Try to use TensorCrossInterpolation conversion if available
tt_type = typeof(tt_data)
if hasmethod(ITensorMPS.MPO, (typeof(tt_data),); kwargs...)
# Try ITensorMPS.MPO with sites keyword argument (for TensorCrossInterpolation)
try
mpo = ITensorMPS.MPO(tt_data; sites=sites, kwargs...)
return TensorTrain(mpo)
catch
# Fallback to generic MPO constructor
mpo = ITensorMPS.MPO(tt_data, sites; kwargs...)
return TensorTrain(mpo)
end
# Convert to Vector{Vector{Index}} format
sites = [[s] for s in sites]
elseif length(sites) > 0 && !(sites[1] isa AbstractVector)
error("sites must be either AbstractVector{<:Index} or AbstractVector{<:AbstractVector{<:Index}}")
end

length(sites) == N || error("Length mismatch: length(sites)=$(length(sites)) must equal length(tt_data)=$N")

# Determine expected number of physical indices per site
expected_physical_per_site = [length(sites[n]) for n in 1:N]

# Validate core dimensionalities and extract link dimensions
left_link_dims = Vector{Int}(undef, N)
right_link_dims = Vector{Int}(undef, N)
for n in 1:N
core = tt_data[n]
nd = ndims(core)
expected_nd = expected_physical_per_site[n] + 2
nd == expected_nd || error("Core $n has ndims=$nd but expected $expected_nd (physical=$(expected_physical_per_site[n]) + 2 links)")
left_link_dims[n] = size(core, 1)
right_link_dims[n] = size(core, nd)
if n > 1 && left_link_dims[n] != right_link_dims[n - 1]
error("Bond mismatch between cores $(n-1) and $n: right=$(right_link_dims[n-1]) vs left=$(left_link_dims[n])")
end
end

# Create link indices from link dimensions for the N-1 bonds
links = N > 1 ? [Index(right_link_dims[n], "Link,l=$n") for n in 1:(N - 1)] : Index[]

# Build ITensors per site
tensors = Vector{ITensor}(undef, N)
for n in 1:N
core = tt_data[n]
nphys = expected_physical_per_site[n]
site_inds_n = sites[n]

# Use site indices as given (no special MPO convention)
site_inds_tuple = Tuple(site_inds_n)

# Validate physical dimensions match the array sizes on axes 2:(1+nphys)
for p in 1:nphys
idx = site_inds_tuple[p]
array_dim = size(core, 1 + p)
ITensors.dim(idx) == array_dim || error(
"Physical dim mismatch at site $n, physical $p: array=$(array_dim) vs index=$(ITensors.dim(idx))",
)
end

# Assemble index tuple in the assumed order: (left?, physical..., right?)
inds_tuple = if n == 1 && n == N
# Single-site
site_inds_tuple
elseif n == 1
# First: physical..., right_link
Tuple(vcat(collect(site_inds_tuple), [links[n]]))
elseif n == N
# Last: left_link, physical...
Tuple(vcat([links[n - 1]], collect(site_inds_tuple)))
else
# Try generic MPO constructor
mpo = ITensorMPS.MPO(tt_data, sites; kwargs...)
return TensorTrain(mpo)
# Middle: left_link, physical..., right_link
Tuple(vcat([links[n - 1]], collect(site_inds_tuple), [links[n]]))
end

tensors[n] = ITensor(core, inds_tuple...)
end

return TensorTrain(tensors)
end

# Iterator implementation
Expand Down
204 changes: 203 additions & 1 deletion test/tensortrain_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
include("util.jl")

import T4AITensorCompat: TensorTrain, dist, siteinds, random_mps, random_mpo
import ITensors: ITensor, Index, random_itensor, dim
import ITensors: ITensor, Index, random_itensor, dim, inds
import ITensorMPS
import ITensors: Algorithm, @Algorithm_str
import LinearAlgebra: norm
Expand Down Expand Up @@ -773,4 +773,206 @@
@test length(sites[1]) == 1
@test sites[1][1] == i1
end

@testset "TensorTrain from AbstractVector{<:AbstractArray} - MPS case" begin
# Create site indices
sites = [Index(2, "Site,n=$n") for n in 1:3]

# Create TT cores as arrays: (left_link, physical, right_link)
# First core: (1, 2, 3) - left boundary is 1
# Middle core: (3, 2, 4) - bond dimension 3
# Last core: (4, 2, 1) - right boundary is 1
core1 = randn(Float64, 1, 2, 3)
core2 = randn(Float64, 3, 2, 4)
core3 = randn(Float64, 4, 2, 1)
tt_data = [core1, core2, core3]

# Construct TensorTrain using AbstractVector{<:Index} format
tt = TensorTrain(tt_data; sites=sites)

# Verify structure
@test length(tt) == 3
@test length(tt.data) == 3

# Verify that each tensor has correct indices
# First tensor should have: (site[1], link[1])
@test length(inds(tt[1])) == 2
@test sites[1] in inds(tt[1])

# Middle tensor should have: (link[1], site[2], link[2])
@test length(inds(tt[2])) == 3
@test sites[2] in inds(tt[2])

# Last tensor should have: (link[2], site[3])
@test length(inds(tt[3])) == 2
@test sites[3] in inds(tt[3])

# Verify physical dimensions match
for n in 1:3
site_inds = siteinds(tt)[n]
@test length(site_inds) == 1
@test site_inds[1] == sites[n]
@test dim(site_inds[1]) == 2
end

# Verify link dimensions
links = T4AITensorCompat.linkinds(tt)
@test length(links) == 2
@test dim(links[1]) == 3
@test dim(links[2]) == 4

# Verify that we can convert back to MPS and the data matches
mps = ITensorMPS.MPS(tt)
@test length(mps) == 3

# Verify round-trip: convert back to arrays and check dimensions
tt_reconstructed = TensorTrain(mps)
for n in 1:3
@test length(inds(tt_reconstructed[n])) == length(inds(tt[n]))
end
end

@testset "TensorTrain from AbstractVector{<:AbstractArray} - MPO case" begin
# Create site indices (MPO: each site has 2 physical indices)
sites = [[Index(2, "Site,n=$n,lower"), Index(2, "Site,n=$n,upper")] for n in 1:3]

# Create TT cores as arrays: (left_link, physical1, physical2, right_link)
# First core: (1, 2, 2, 3)
# Middle core: (3, 2, 2, 4)
# Last core: (4, 2, 2, 1)
core1 = randn(Float64, 1, 2, 2, 3)
core2 = randn(Float64, 3, 2, 2, 4)
core3 = randn(Float64, 4, 2, 2, 1)
tt_data = [core1, core2, core3]

# Construct TensorTrain using AbstractVector{AbstractVector{<:Index}} format
tt = TensorTrain(tt_data; sites=sites)

# Verify structure
@test length(tt) == 3
@test length(tt.data) == 3

# Verify that each tensor has correct indices
# First tensor should have: (site[1][1], site[1][2], link[1])
@test length(inds(tt[1])) == 3

# Middle tensor should have: (link[1], site[2][1], site[2][2], link[2])
@test length(inds(tt[2])) == 4

# Last tensor should have: (link[2], site[3][1], site[3][2])
@test length(inds(tt[3])) == 3

# Verify physical dimensions match
for n in 1:3
site_inds = siteinds(tt)[n]
@test length(site_inds) == 2
@test Set(site_inds) == Set(sites[n])
@test dim(site_inds[1]) == 2
@test dim(site_inds[2]) == 2
end

# Verify link dimensions
links = T4AITensorCompat.linkinds(tt)
@test length(links) == 2
@test dim(links[1]) == 3
@test dim(links[2]) == 4

# Verify that we can convert back to MPO and the data matches
mpo = ITensorMPS.MPO(tt)
@test length(mpo) == 3
end

@testset "TensorTrain from AbstractVector{<:AbstractArray} - arbitrary physical indices" begin
# Test case with different numbers of physical indices per site
# Site 1: 1 physical index
# Site 2: 2 physical indices
# Site 3: 3 physical indices
sites = [
[Index(2, "Site1")],
[Index(3, "Site2,lower"), Index(3, "Site2,upper")],
[Index(2, "Site3,a"), Index(2, "Site3,b"), Index(2, "Site3,c")]
]

# Create TT cores with appropriate dimensions
# Core 1: (1, 2, 3) - 1 physical index
# Core 2: (3, 3, 3, 4) - 2 physical indices
# Core 3: (4, 2, 2, 2, 1) - 3 physical indices
core1 = randn(Float64, 1, 2, 3)
core2 = randn(Float64, 3, 3, 3, 4)
core3 = randn(Float64, 4, 2, 2, 2, 1)
tt_data = [core1, core2, core3]

# Construct TensorTrain
tt = TensorTrain(tt_data; sites=sites)

# Verify structure
@test length(tt) == 3

# Verify physical dimensions match
for n in 1:3
site_inds = siteinds(tt)[n]
@test length(site_inds) == length(sites[n])
@test Set(site_inds) == Set(sites[n])
end

# Verify link dimensions
links = T4AITensorCompat.linkinds(tt)
@test length(links) == 2
@test dim(links[1]) == 3
@test dim(links[2]) == 4
end

@testset "TensorTrain from AbstractVector{<:AbstractArray} - error cases" begin
sites = [Index(2, "Site,n=$n") for n in 1:3]

# Test length mismatch
core1 = randn(Float64, 1, 2, 3)
core2 = randn(Float64, 3, 2, 4)
wrong_sites = [Index(2, "Site,n=$n") for n in 1:2] # Wrong length
@test_throws Exception TensorTrain([core1, core2]; sites=wrong_sites)

# Test dimension mismatch - wrong number of dimensions
wrong_core1 = randn(Float64, 1, 2, 3, 4) # 4 dims instead of 3
@test_throws Exception TensorTrain([wrong_core1, core2]; sites=sites)

# Test physical dimension mismatch
wrong_sites_dim = [Index(3, "Site,n=$n") for n in 1:3] # dim=3 instead of 2
@test_throws Exception TensorTrain([core1, core2, randn(Float64, 4, 3, 1)]; sites=wrong_sites_dim)

# Test bond dimension mismatch
wrong_core2 = randn(Float64, 5, 2, 4) # Left link dim=5, but core1 right link dim=3
@test_throws Exception TensorTrain([core1, wrong_core2]; sites=sites[1:2])
end

@testset "TensorTrain from AbstractVector{<:AbstractArray} - single site" begin
# Single site MPS
sites = [Index(2, "Site1")]
core1 = randn(Float64, 1, 2, 1)
tt = TensorTrain([core1]; sites=sites)

@test length(tt) == 1
@test length(inds(tt[1])) == 1
@test sites[1] in inds(tt[1])

# Single site MPO
sites_mpo = [[Index(2, "Site1,lower"), Index(2, "Site1,upper")]]
core1_mpo = randn(Float64, 1, 2, 2, 1)
tt_mpo = TensorTrain([core1_mpo]; sites=sites_mpo)

@test length(tt_mpo) == 1
@test length(inds(tt_mpo[1])) == 2
site_inds = siteinds(tt_mpo)[1]
@test length(site_inds) == 2
@test Set(site_inds) == Set(sites_mpo[1])
end

@testset "TensorTrain from AbstractVector{<:AbstractArray} - empty" begin
# Empty TensorTrain
tt = TensorTrain(Vector{Array{Float64,3}}(); sites=Index[])
@test length(tt) == 0

# Empty TensorTrain with empty sites vector
tt2 = TensorTrain(Vector{Array{Float64,3}}(); sites=Vector{Index}[])
@test length(tt2) == 0
end
end
Loading